From e22cb49c16c4b708be97f50f61d6bd7d9934d105 Mon Sep 17 00:00:00 2001 From: Mantissagithub Date: Wed, 3 Jun 2026 02:37:34 +0530 Subject: [PATCH 1/4] fix(ppo): reduce fused logprob memory peak --- .github/workflows/pr-test.yml | 2 +- .github/workflows/pr-test.yml.j2 | 1 + examples/retool/retool_qwen3_4b_rl.sh | 4 +- slime/backends/megatron_utils/model.py | 26 +++ slime/utils/ppo_utils.py | 137 +++++++++----- tests/test_logprob_entropy_fused.py | 235 +++++++++++++++++++++++++ tools/repro_1951.py | 166 +++++++++++++++++ 7 files changed, 527 insertions(+), 44 deletions(-) create mode 100644 tests/test_logprob_entropy_fused.py create mode 100644 tools/repro_1951.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f329d10c2b..cf7d094ca6 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -454,7 +454,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_entropy_fused.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 0122182bdd..4afd6bf9da 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -74,6 +74,7 @@ {'test_file': 'test_metric_report.py', 'num_gpus': 0}, {'test_file': 'test_metric_report_dist.py', 'num_gpus': 0}, {'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0}, + {'test_file': 'test_logprob_entropy_fused.py', 'num_gpus': 0}, {'test_file': 'test_value_temperature.py', 'num_gpus': 0}, {'test_file': 'test_rm_f1.py', 'num_gpus': 0}, {'test_file': 'test_rm_gpqa.py', 'num_gpus': 0}, diff --git a/examples/retool/retool_qwen3_4b_rl.sh b/examples/retool/retool_qwen3_4b_rl.sh index 32a837f394..01b4e84bd4 100644 --- a/examples/retool/retool_qwen3_4b_rl.sh +++ b/examples/retool/retool_qwen3_4b_rl.sh @@ -75,6 +75,8 @@ PERF_ARGS=( # --micro-batch-size 1 --use-dynamic-batch-size --max-tokens-per-gpu 9216 + # Bound the fused cross-entropy [tokens, vocab] transient that can OOM on long retool traces. + --log-probs-chunk-size 1024 ) GRPO_ARGS=( @@ -153,4 +155,4 @@ ray job submit --address="http://127.0.0.1:8265" \ ${EVAL_ARGS[@]} \ ${SGLANG_ARGS[@]} \ ${MISC_ARGS[@]} \ - ${CUSTOM_ARGS[@]} \ No newline at end of file + ${CUSTOM_ARGS[@]} diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 96b20c6a21..50a6314453 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -39,6 +39,23 @@ logger = logging.getLogger(__name__) +def _mem_probe_enabled() -> bool: + return os.environ.get("SLIME_MEM_PROBE", "0") == "1" and torch.cuda.is_available() + + +def _log_train_step_mem_probe(rollout_id: int, step_id: int, start_allocated: int) -> None: + max_allocated = torch.cuda.max_memory_allocated() + logger.info( + "SLIME_MEM_PROBE train_one_step rollout_id=%s step_id=%s " + "allocated_start=%s allocated_peak=%s allocated_peak_delta=%s", + rollout_id, + step_id, + start_allocated, + max_allocated, + max_allocated - start_allocated, + ) + + def _disable_tqdm_for_non_main_rank() -> bool: return not ( mpu.get_data_parallel_rank(with_context_parallel=True) == 0 @@ -455,6 +472,11 @@ def train_one_step( and gradient norm for logging. """ args = get_args() + mem_probe = _mem_probe_enabled() + mem_probe_start_allocated = 0 + if mem_probe: + torch.cuda.reset_peak_memory_stats() + mem_probe_start_allocated = torch.cuda.memory_allocated() # Set grad to zero. for model_chunk in model: @@ -601,7 +623,11 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p cp_size=mpu.get_context_parallel_world_size(), dp_with_cp_group=mpu.get_data_parallel_group(with_context_parallel=True), ) + if mem_probe: + _log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated) return loss_reduced, grad_norm + if mem_probe: + _log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated) return {}, grad_norm diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 2a858e7a3f..a5d8402691 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -158,44 +158,78 @@ def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: return -fused_vocab_parallel_cross_entropy(logits, tokens, process_group) -# from https://github.com/volcengine/verl/blob/0bdf7f469854815177e73dcfe9e420836c952e6e/verl/utils/megatron/tensor_parallel.py#L99 -class _VocabParallelEntropy(torch.autograd.Function): +class _VocabParallelLogProbsAndEntropy(torch.autograd.Function): @staticmethod - def forward(ctx, vocab_parallel_logits: torch.Tensor, process_group: dist.ProcessGroup) -> torch.Tensor: + def forward(ctx, vocab_parallel_logits: torch.Tensor, target: torch.Tensor, process_group): + from megatron.core.tensor_parallel.utils import VocabUtility - @torch.compile(dynamic=True) - def mul_reduce(a, b): - return (a * b).sum(dim=-1, keepdim=True) - - logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values + logits_max = vocab_parallel_logits.max(dim=-1).values dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) - normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max - normalized_exp_logits = normalized_vocab_parallel_logits.exp_() - normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True) - dist.all_reduce(normalized_sum_exp_logits, group=process_group) - softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits) - sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits) - dist.all_reduce(sum_softmax_times_logits, group=process_group) - entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits - ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits) - return entropy.squeeze(dim=-1) + + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size(-1) + vocab_start_index, vocab_end_index = get_vocab_range( + partition_vocab_size, process_group.rank(), process_group.size() + ) + + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size(0), device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + torch.exp(vocab_parallel_logits, out=vocab_parallel_logits) + sum_exp_logits = vocab_parallel_logits.sum(dim=-1) + dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=process_group) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + log_sum_exp = sum_exp_logits.log() + log_prob = predicted_logits - log_sum_exp + softmax = vocab_parallel_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + local_entropy = torch.zeros_like(log_sum_exp) + for softmax_chunk in softmax.chunk(max(1, softmax.size(-1) // 8192), dim=-1): + local_entropy -= torch.xlogy(softmax_chunk, softmax_chunk).sum(dim=-1) + dist.all_reduce(local_entropy, op=dist.ReduceOp.SUM, group=process_group) + + ctx.save_for_backward(softmax, target_mask, masked_target_1d, local_entropy) + return log_prob, local_entropy.squeeze(dim=-1) @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors - # reuse softmax_logits as grad - vocab_parallel_logits.sub_(sum_softmax_times_logits) - softmax_logits.mul_(vocab_parallel_logits) - softmax_logits.mul_(grad_output.unsqueeze(dim=-1)) - # recover vocab_parallel_logits - vocab_parallel_logits.add_(sum_softmax_times_logits) - softmax_logits.mul_(-1) - return softmax_logits, None + def backward(ctx, grad_log_prob: torch.Tensor, grad_entropy: torch.Tensor): + softmax, target_mask, masked_target_1d, entropy = ctx.saved_tensors + partition_vocab_size = softmax.size(-1) + if grad_entropy is None: + grad_entropy = torch.zeros_like(entropy.squeeze(dim=-1)) + if grad_log_prob is None: + grad_log_prob = torch.zeros_like(entropy) + + log_softmax = torch.where(softmax > 0, softmax.log(), torch.zeros_like(softmax)) + entropy_grad = softmax * ( + grad_entropy.unsqueeze(dim=-1).unsqueeze(dim=-1) * (-entropy.unsqueeze(dim=-1) - log_softmax) + ) + grad_input = softmax * -grad_log_prob.unsqueeze(dim=-1) + grad_input_2d = grad_input.view(-1, partition_vocab_size) + arange_1d = torch.arange(start=0, end=grad_input_2d.size(0), device=grad_input_2d.device) + softmax_update = 1.0 - target_mask.view(-1).float() + grad_input_2d[arange_1d, masked_target_1d] += grad_log_prob.view(-1) * softmax_update -def compute_entropy_from_logits(logits: torch.Tensor, process_group) -> torch.Tensor: - return _VocabParallelEntropy.apply(logits, process_group) + return grad_input.to(torch.bfloat16) + entropy_grad, None, None + + +def compute_log_probs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor, process_group): + logits = logits.unsqueeze(1) + tokens = tokens.unsqueeze(1) + return _VocabParallelLogProbsAndEntropy.apply(logits, tokens, process_group) def get_grpo_returns( @@ -646,6 +680,18 @@ def chunked_gae( return advantages, returns +def _clone_if_grad_tracked(logits: torch.Tensor) -> torch.Tensor: + # Megatron-LM's fused CE mutates its float32 input in place (subtract-max, + # exp(out=...), div_; see INVESTIGATION.md). That is safe to hand over directly + # only when autograd will not observe the mutation. When grad is tracked, an + # in-place write on a (view of a) grad-requiring tensor corrupts the graph and + # raises (issue #1951). Clone exactly when grad is tracked; otherwise pass the + # tensor through so the no-grad ref/old-logprob path keeps its peak-memory win. + if torch.is_grad_enabled() and logits.requires_grad: + return logits.clone() + return logits + + def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1): logits = logits.contiguous() entropy = None @@ -657,22 +703,29 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool if with_entropy: entropys = [] - for logits_chunk in logits_chunks: - entropy_input = logits_chunk.clone() - entropys.append(compute_entropy_from_logits(entropy_input, tp_group)) + log_probs = [] + for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + # The fused helper computes log-probs and entropy from one tensor, + # replacing the two separate destructive passes (and two clones). + log_prob, entropy_chunk = compute_log_probs_and_entropy( + _clone_if_grad_tracked(logits_chunk), tokens_chunk, tp_group + ) + log_probs.append(log_prob) + entropys.append(entropy_chunk) entropy = torch.cat(entropys, dim=0) - - log_probs = [] - for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): - log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group) - log_probs.append(log_prob) + else: + log_probs = [] + for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + log_prob = compute_log_probs(_clone_if_grad_tracked(logits_chunk), tokens_chunk, tp_group) + log_probs.append(log_prob) log_prob = torch.cat(log_probs, dim=0) else: if with_entropy: - entropy_input = logits.clone() - entropy = compute_entropy_from_logits(entropy_input, tp_group) - - log_prob = compute_log_probs(logits.clone(), tokens, tp_group) + # The fused helper computes log-probs and entropy from one tensor, + # replacing the two separate destructive passes (and two clones). + log_prob, entropy = compute_log_probs_and_entropy(_clone_if_grad_tracked(logits), tokens, tp_group) + else: + log_prob = compute_log_probs(_clone_if_grad_tracked(logits), tokens, tp_group) else: log_prob = logits.new_zeros((0,)) if with_entropy: diff --git a/tests/test_logprob_entropy_fused.py b/tests/test_logprob_entropy_fused.py new file mode 100644 index 0000000000..4ac0ff7e4f --- /dev/null +++ b/tests/test_logprob_entropy_fused.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import sys +import types +from contextlib import contextmanager +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from _cp_dist_helpers import free_port + + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +NUM_GPUS = 0 + + +class _FakeSingleRankGroup: + def rank(self) -> int: + return 0 + + def size(self) -> int: + return 1 + + +def _install_megatron_stubs() -> None: + megatron = sys.modules.setdefault("megatron", types.ModuleType("megatron")) + core = sys.modules.setdefault("megatron.core", types.ModuleType("megatron.core")) + fusions = sys.modules.setdefault("megatron.core.fusions", types.ModuleType("megatron.core.fusions")) + fused = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + tensor_parallel = sys.modules.setdefault( + "megatron.core.tensor_parallel", types.ModuleType("megatron.core.tensor_parallel") + ) + utils = types.ModuleType("megatron.core.tensor_parallel.utils") + + class VocabUtility: + @staticmethod + def vocab_range_from_per_partition_vocab_size(partition_vocab_size: int, rank: int, world_size: int): + assert world_size > 0 + assert 0 <= rank < world_size + start = rank * partition_vocab_size + return start, start + partition_vocab_size + + class _MockVocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits: torch.Tensor, target: torch.Tensor, process_group): + del process_group + logits_max = logits.max(dim=-1, keepdim=True).values + logits.sub_(logits_max) + predicted_logits = logits.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) + torch.exp(logits, out=logits) + sum_exp_logits = logits.sum(dim=-1) + logits.div_(sum_exp_logits.unsqueeze(-1)) + ctx.save_for_backward(logits, target) + return sum_exp_logits.log() - predicted_logits + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + softmax, target = ctx.saved_tensors + grad_input = softmax.clone() + grad_input.scatter_add_( + dim=-1, + index=target.unsqueeze(-1), + src=-torch.ones_like(target, dtype=grad_input.dtype).unsqueeze(-1), + ) + grad_input.mul_(grad_output.unsqueeze(-1)) + return grad_input.to(torch.bfloat16), None, None + + def fused_vocab_parallel_cross_entropy(logits: torch.Tensor, target: torch.Tensor, process_group): + return _MockVocabParallelCrossEntropy.apply(logits, target, process_group) + + fused.fused_vocab_parallel_cross_entropy = fused_vocab_parallel_cross_entropy + utils.VocabUtility = VocabUtility + fusions.fused_cross_entropy = fused + tensor_parallel.utils = utils + core.fusions = fusions + core.tensor_parallel = tensor_parallel + megatron.core = core + sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused + sys.modules["megatron.core.tensor_parallel.utils"] = utils + + +@contextmanager +def _single_rank_all_reduce(): + original_all_reduce = dist.all_reduce + + def all_reduce(tensor, op=None, group=None, async_op=False): + del tensor, op, group + if async_op: + raise NotImplementedError("async all_reduce is not needed by this test") + return None + + dist.all_reduce = all_reduce + try: + yield + finally: + dist.all_reduce = original_all_reduce + + +def _naive_log_probs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor): + log_softmax = torch.log_softmax(logits.float(), dim=-1) + log_probs = log_softmax.gather(dim=-1, index=tokens.unsqueeze(-1)) + entropy = -(log_softmax.exp() * log_softmax).sum(dim=-1) + return log_probs, entropy + + +def _make_inputs(requires_grad: bool = False): + torch.manual_seed(1234) + logits = torch.randn(9, 17, dtype=torch.float32) + logits.requires_grad_(requires_grad) + tokens = torch.randint(0, logits.size(-1), (logits.size(0),), dtype=torch.long) + return logits, tokens + + +def test_fused_forward_matches_naive_reference(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs() + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=True + ) + + ref_log_probs, ref_entropy = _naive_log_probs_and_entropy(logits, tokens) + torch.testing.assert_close(log_probs, ref_log_probs, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(entropy, ref_entropy, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("with_entropy", [False, True]) +def test_chunked_matches_unchunked(with_entropy: bool): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs() + with _single_rank_all_reduce(): + full_log_probs, full_entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=with_entropy, chunk_size=-1 + ) + chunk_log_probs, chunk_entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=with_entropy, chunk_size=4 + ) + + torch.testing.assert_close(chunk_log_probs, full_log_probs, atol=1e-5, rtol=1e-5) + if with_entropy: + torch.testing.assert_close(chunk_entropy, full_entropy, atol=1e-5, rtol=1e-5) + else: + assert chunk_entropy is None + + +def test_no_entropy_chunked_backward_preserves_input_grad(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits, tokens, _FakeSingleRankGroup(), with_entropy=False, chunk_size=4 + ) + assert entropy is None + log_probs.float().sum().backward() + + assert logits.grad is not None + assert torch.isfinite(logits.grad).all() + assert logits.grad.abs().sum() > 0 + + +def test_fused_backward_matches_naive_reference_with_bf16_tolerance(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy(logits, tokens, _FakeSingleRankGroup(), with_entropy=True) + (log_probs.float().sum() + 0.13 * entropy.float().sum()).backward() + + ref_logits = logits.detach().clone().requires_grad_(True) + ref_log_probs, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + (ref_log_probs.float().sum() + 0.13 * ref_entropy.float().sum()).backward() + + torch.testing.assert_close(logits.grad, ref_logits.grad, atol=4e-3, rtol=4e-3) + + +def _tp2_worker(rank: int, master_port: int, result_path: str) -> None: + import os + + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + dist.init_process_group(backend="gloo", rank=rank, world_size=2) + try: + _install_megatron_stubs() + from slime.utils.ppo_utils import compute_log_probs_and_entropy + + torch.manual_seed(2024) + batch_size = 7 + partition_vocab_size = 5 + full_vocab_size = partition_vocab_size * 2 + full_logits = torch.randn(batch_size, full_vocab_size, dtype=torch.float32) + tokens = torch.randint(0, full_vocab_size, (batch_size,), dtype=torch.long) + start = rank * partition_vocab_size + local_logits = full_logits[:, start : start + partition_vocab_size].contiguous().requires_grad_(True) + + log_probs, entropy = compute_log_probs_and_entropy(local_logits, tokens, dist.group.WORLD) + (log_probs.float().sum() + 0.13 * entropy.float().sum()).backward() + + gathered_grads = [torch.empty_like(local_logits.grad) for _ in range(2)] + dist.all_gather(gathered_grads, local_logits.grad) + + if rank == 0: + ref_logits = full_logits.detach().clone().requires_grad_(True) + ref_log_probs, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + (ref_log_probs.float().sum() + 0.13 * ref_entropy.float().sum()).backward() + + torch.testing.assert_close(log_probs, ref_log_probs, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(entropy, ref_entropy, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(torch.cat(gathered_grads, dim=-1), ref_logits.grad, atol=4e-3, rtol=4e-3) + with open(result_path, "w") as f: + f.write("ok") + finally: + dist.destroy_process_group() + + +def test_tp2_fused_backward_matches_full_vocab_reference(tmp_path): + result_path = str(tmp_path / "tp2_result.txt") + mp.spawn(_tp2_worker, args=(free_port(), result_path), nprocs=2, join=True) + assert (tmp_path / "tp2_result.txt").read_text() == "ok" + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tools/repro_1951.py b/tools/repro_1951.py new file mode 100644 index 0000000000..2f2e7b6a5c --- /dev/null +++ b/tools/repro_1951.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +"""Synthetic reproducer for issue #1951 log-prob/entropy memory peaks.""" + +from __future__ import annotations + +import argparse +import sys +import types +from contextlib import contextmanager + +import torch +import torch.distributed as dist + + +class _FakeSingleRankGroup: + def rank(self) -> int: + return 0 + + def size(self) -> int: + return 1 + + +def _install_mock_fused_cross_entropy() -> None: + """Install a single-rank Megatron fused CE stand-in for local repro runs.""" + megatron = sys.modules.setdefault("megatron", types.ModuleType("megatron")) + core = sys.modules.setdefault("megatron.core", types.ModuleType("megatron.core")) + fusions = sys.modules.setdefault("megatron.core.fusions", types.ModuleType("megatron.core.fusions")) + fused = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + tensor_parallel = sys.modules.setdefault( + "megatron.core.tensor_parallel", types.ModuleType("megatron.core.tensor_parallel") + ) + utils = types.ModuleType("megatron.core.tensor_parallel.utils") + + class VocabUtility: + @staticmethod + def vocab_range_from_per_partition_vocab_size(partition_vocab_size: int, rank: int, world_size: int): + assert world_size == 1 + assert rank == 0 + return 0, partition_vocab_size + + class _MockVocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits: torch.Tensor, target: torch.Tensor, process_group): + del process_group + logits = logits.float() + logits_max = logits.max(dim=-1, keepdim=True).values + logits.sub_(logits_max) + predicted_logits = logits.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) + torch.exp(logits, out=logits) + sum_exp_logits = logits.sum(dim=-1) + logits.div_(sum_exp_logits.unsqueeze(-1)) + ctx.save_for_backward(logits, target) + return sum_exp_logits.log() - predicted_logits + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + softmax, target = ctx.saved_tensors + grad_input = softmax.clone() + grad_input.scatter_add_( + dim=-1, + index=target.unsqueeze(-1), + src=-torch.ones_like(target, dtype=grad_input.dtype).unsqueeze(-1), + ) + grad_input.mul_(grad_output.unsqueeze(-1)) + return grad_input.to(torch.bfloat16), None, None + + def fused_vocab_parallel_cross_entropy(logits: torch.Tensor, target: torch.Tensor, process_group): + return _MockVocabParallelCrossEntropy.apply(logits, target, process_group) + + fused.fused_vocab_parallel_cross_entropy = fused_vocab_parallel_cross_entropy + utils.VocabUtility = VocabUtility + fusions.fused_cross_entropy = fused + tensor_parallel.utils = utils + core.fusions = fusions + core.tensor_parallel = tensor_parallel + megatron.core = core + sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused + sys.modules["megatron.core.tensor_parallel.utils"] = utils + + +@contextmanager +def _single_rank_all_reduce(): + original_all_reduce = dist.all_reduce + + def all_reduce(tensor, op=None, group=None, async_op=False): + del tensor, op, group + if async_op: + raise NotImplementedError("async all_reduce is not needed by this repro") + return None + + dist.all_reduce = all_reduce + try: + yield + finally: + dist.all_reduce = original_all_reduce + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--batch", "-B", type=int, default=4096, help="Number of token positions.") + parser.add_argument("--vocab", "-V", type=int, default=151936, help="Vocabulary dimension.") + parser.add_argument("--chunk-size", type=int, default=-1, help="Forwarded to calculate_log_probs_and_entropy.") + parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="float32") + parser.add_argument("--with-entropy", action="store_true", help="Also compute entropy.") + parser.add_argument("--backward", action="store_true", help="Run backward through the returned tensors.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--use-real-megatron", action="store_true", help="Do not install the local fused CE mock.") + return parser.parse_args() + + +def _fmt_bytes(value: int) -> str: + return f"{value / 1024**3:.3f} GiB ({value} bytes)" + + +def main() -> None: + args = _parse_args() + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required for peak-memory measurement") + + if not args.use_real_megatron: + _install_mock_fused_cross_entropy() + + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + dtype = getattr(torch, args.dtype) + device = torch.device("cuda") + torch.manual_seed(args.seed) + torch.cuda.reset_peak_memory_stats() + + logits = torch.randn(args.batch, args.vocab, device=device, dtype=dtype) + if args.backward: + logits.requires_grad_(True) + tokens = torch.randint(args.vocab, (args.batch,), device=device) + torch.cuda.synchronize() + + allocated_after_logits = torch.cuda.memory_allocated() + torch.cuda.reset_peak_memory_stats() + + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits, + tokens, + _FakeSingleRankGroup(), + with_entropy=args.with_entropy, + chunk_size=args.chunk_size, + ) + if args.backward: + loss = log_probs.float().sum() + if entropy is not None: + loss = loss + entropy.float().sum() + loss.backward() + + torch.cuda.synchronize() + peak = torch.cuda.max_memory_allocated() + current = torch.cuda.memory_allocated() + + print(f"shape=({args.batch}, {args.vocab}) dtype={args.dtype} with_entropy={args.with_entropy}") + print(f"chunk_size={args.chunk_size} backward={args.backward} mock_megatron={not args.use_real_megatron}") + print(f"allocated_after_logits={_fmt_bytes(allocated_after_logits)}") + print(f"peak_during_call={_fmt_bytes(peak)}") + print(f"peak_delta_after_logits={_fmt_bytes(peak - allocated_after_logits)}") + print(f"allocated_after_call={_fmt_bytes(current)}") + + +if __name__ == "__main__": + main() From 902a66f28192f88374e0d6950bfcfa909487f007 Mon Sep 17 00:00:00 2001 From: Mantissagithub Date: Wed, 3 Jun 2026 08:17:29 +0530 Subject: [PATCH 2/4] perf(ppo): make fused logprob+entropy backward allocation-lean MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fused backward materialized three new full-vocab [T,1,V] fp32 tensors (log_softmax, entropy_grad, grad_input) on top of the saved softmax, so the unchunked path peaked ~4x [T,V] — worse than the pre-fix two-pass code (H200 B=16384 V=151936: 60.3 vs 51.1 GiB). Reuse the saved softmax buffer in place as the gradient (mirroring Megatron's VocabParallelCrossEntropy.calculate_gradients): at most one extra full-vocab temp, and only when entropy gradient flows. Return a bf16 grad like Megatron's CE (drops the bf16+fp32 mix). H200 peak (with_entropy, backward), now <= main at every chunk size: chunk -1 : 51.1 -> 32.5 GiB (was 60.3, a regression) chunk 1024: 38.0 -> 27.8 GiB chunk 512 : 37.6 -> 27.8 GiB Backward verified vs naive autograd within bf16 tol (TP=1 and TP=2). Also drop a stale INVESTIGATION.md reference in the clone-guard comment. --- slime/utils/ppo_utils.py | 45 ++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index a5d8402691..2203e42a6d 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -200,30 +200,43 @@ def forward(ctx, vocab_parallel_logits: torch.Tensor, target: torch.Tensor, proc local_entropy -= torch.xlogy(softmax_chunk, softmax_chunk).sum(dim=-1) dist.all_reduce(local_entropy, op=dist.ReduceOp.SUM, group=process_group) + # `softmax` is reused in place as the gradient buffer in backward (no double-backward). ctx.save_for_backward(softmax, target_mask, masked_target_1d, local_entropy) return log_prob, local_entropy.squeeze(dim=-1) @staticmethod def backward(ctx, grad_log_prob: torch.Tensor, grad_entropy: torch.Tensor): + # Local grad wrt logit z_j (no cross-rank reduce needed): + # g_j = softmax_j * [ -grad_log_prob - grad_entropy * (entropy + log_softmax_j) ] + # + grad_log_prob * 1{j == target} + # The saved softmax buffer is reused in place as the gradient, so the only extra + # full-vocab allocation is `log_softmax`, and only when entropy gradient flows. softmax, target_mask, masked_target_1d, entropy = ctx.saved_tensors partition_vocab_size = softmax.size(-1) - if grad_entropy is None: - grad_entropy = torch.zeros_like(entropy.squeeze(dim=-1)) - if grad_log_prob is None: - grad_log_prob = torch.zeros_like(entropy) - - log_softmax = torch.where(softmax > 0, softmax.log(), torch.zeros_like(softmax)) - entropy_grad = softmax * ( - grad_entropy.unsqueeze(dim=-1).unsqueeze(dim=-1) * (-entropy.unsqueeze(dim=-1) - log_softmax) - ) - grad_input = softmax * -grad_log_prob.unsqueeze(dim=-1) - grad_input_2d = grad_input.view(-1, partition_vocab_size) - arange_1d = torch.arange(start=0, end=grad_input_2d.size(0), device=grad_input_2d.device) - softmax_update = 1.0 - target_mask.view(-1).float() - grad_input_2d[arange_1d, masked_target_1d] += grad_log_prob.view(-1) * softmax_update + if grad_entropy is not None: + # log_softmax = log(softmax); softmax underflow (==0) -> log(0)=-inf, mapped to 0 + # so the softmax==0 positions contribute nothing (0 * finite). + log_softmax = softmax.log() + log_softmax.nan_to_num_(neginf=0.0) + log_softmax.add_(entropy.unsqueeze(dim=-1)) # entropy + log_softmax + log_softmax.mul_(grad_entropy.unsqueeze(dim=-1).unsqueeze(dim=-1)) + if grad_log_prob is not None: + log_softmax.add_(grad_log_prob.unsqueeze(dim=-1)) + softmax.mul_(log_softmax).neg_() # softmax * [-grad_log_prob - grad_entropy*(H+log p)] + del log_softmax + elif grad_log_prob is not None: + softmax.mul_(grad_log_prob.unsqueeze(dim=-1).neg()) # -softmax * grad_log_prob + else: + return None, None, None + + if grad_log_prob is not None: + grad_2d = softmax.view(-1, partition_vocab_size) + arange_1d = torch.arange(start=0, end=grad_2d.size(0), device=grad_2d.device) + softmax_update = 1.0 - target_mask.view(-1).to(softmax.dtype) + grad_2d[arange_1d, masked_target_1d] += grad_log_prob.reshape(-1) * softmax_update - return grad_input.to(torch.bfloat16) + entropy_grad, None, None + return softmax.to(torch.bfloat16), None, None def compute_log_probs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor, process_group): @@ -682,7 +695,7 @@ def chunked_gae( def _clone_if_grad_tracked(logits: torch.Tensor) -> torch.Tensor: # Megatron-LM's fused CE mutates its float32 input in place (subtract-max, - # exp(out=...), div_; see INVESTIGATION.md). That is safe to hand over directly + # exp(out=...), div_). That is safe to hand over directly # only when autograd will not observe the mutation. When grad is tracked, an # in-place write on a (view of a) grad-requiring tensor corrupts the graph and # raises (issue #1951). Clone exactly when grad is tracked; otherwise pass the From 7f17d547c5f328f66ee029960a5500c4280a6c85 Mon Sep 17 00:00:00 2001 From: Mantissagithub Date: Wed, 10 Jun 2026 01:39:09 +0530 Subject: [PATCH 3/4] test(ppo): cover entropy-only fused backward branch --- slime/utils/ppo_utils.py | 4 ++++ tests/test_logprob_entropy_fused.py | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 2203e42a6d..37bf064bf3 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -164,6 +164,10 @@ class _VocabParallelLogProbsAndEntropy(torch.autograd.Function): def forward(ctx, vocab_parallel_logits: torch.Tensor, target: torch.Tensor, process_group): from megatron.core.tensor_parallel.utils import VocabUtility + # Pass None (not a zero-filled tensor) for an output whose grad does not flow, + # so the single-output backward paths skip a wasted full-vocab allocation. + ctx.set_materialize_grads(False) + logits_max = vocab_parallel_logits.max(dim=-1).values dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) diff --git a/tests/test_logprob_entropy_fused.py b/tests/test_logprob_entropy_fused.py index 4ac0ff7e4f..341f069788 100644 --- a/tests/test_logprob_entropy_fused.py +++ b/tests/test_logprob_entropy_fused.py @@ -186,6 +186,29 @@ def test_fused_backward_matches_naive_reference_with_bf16_tolerance(): torch.testing.assert_close(logits.grad, ref_logits.grad, atol=4e-3, rtol=4e-3) +def test_fused_entropy_only_backward_matches_naive_reference(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits, tokens, _FakeSingleRankGroup(), with_entropy=True + ) + # Only entropy contributes to the loss, so autograd passes + # grad_log_prob=None into the fused backward, exercising the + # entropy-only branch. + entropy.float().sum().backward() + + ref_logits = logits.detach().clone().requires_grad_(True) + _, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + ref_entropy.float().sum().backward() + + assert logits.grad is not None + assert torch.isfinite(logits.grad).all() + torch.testing.assert_close(logits.grad, ref_logits.grad, atol=4e-3, rtol=4e-3) + + def _tp2_worker(rank: int, master_port: int, result_path: str) -> None: import os From fda5193c82bfa06d43b371e879d11db78bca9831 Mon Sep 17 00:00:00 2001 From: Mantissagithub Date: Wed, 10 Jun 2026 06:36:14 +0530 Subject: [PATCH 4/4] style(test): satisfy black formatting in entropy-only test --- tests/test_logprob_entropy_fused.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_logprob_entropy_fused.py b/tests/test_logprob_entropy_fused.py index 341f069788..0e579ecd06 100644 --- a/tests/test_logprob_entropy_fused.py +++ b/tests/test_logprob_entropy_fused.py @@ -192,9 +192,7 @@ def test_fused_entropy_only_backward_matches_naive_reference(): logits, tokens = _make_inputs(requires_grad=True) with _single_rank_all_reduce(): - log_probs, entropy = calculate_log_probs_and_entropy( - logits, tokens, _FakeSingleRankGroup(), with_entropy=True - ) + log_probs, entropy = calculate_log_probs_and_entropy(logits, tokens, _FakeSingleRankGroup(), with_entropy=True) # Only entropy contributes to the loss, so autograd passes # grad_log_prob=None into the fused backward, exercising the # entropy-only branch.