diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 1562e6b9ea..43639ab403 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -76,6 +76,7 @@ {'test_file': 'test_metric_report.py', 'num_gpus': 0}, {'test_file': 'test_metric_report_dist.py', 'num_gpus': 0}, {'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0}, + {'test_file': 'test_logprob_entropy_fused.py', 'num_gpus': 0}, {'test_file': 'test_value_temperature.py', 'num_gpus': 0}, {'test_file': 'test_rm_f1.py', 'num_gpus': 0}, {'test_file': 'test_rm_gpqa.py', 'num_gpus': 0}, diff --git a/examples/retool/retool_qwen3_4b_rl.sh b/examples/retool/retool_qwen3_4b_rl.sh index 32a837f394..01b4e84bd4 100644 --- a/examples/retool/retool_qwen3_4b_rl.sh +++ b/examples/retool/retool_qwen3_4b_rl.sh @@ -75,6 +75,8 @@ PERF_ARGS=( # --micro-batch-size 1 --use-dynamic-batch-size --max-tokens-per-gpu 9216 + # Bound the fused cross-entropy [tokens, vocab] transient that can OOM on long retool traces. + --log-probs-chunk-size 1024 ) GRPO_ARGS=( @@ -153,4 +155,4 @@ ray job submit --address="http://127.0.0.1:8265" \ ${EVAL_ARGS[@]} \ ${SGLANG_ARGS[@]} \ ${MISC_ARGS[@]} \ - ${CUSTOM_ARGS[@]} \ No newline at end of file + ${CUSTOM_ARGS[@]} diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index db6020a94d..9187ec5cf4 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -39,6 +39,23 @@ logger = logging.getLogger(__name__) +def _mem_probe_enabled() -> bool: + return os.environ.get("SLIME_MEM_PROBE", "0") == "1" and torch.cuda.is_available() + + +def _log_train_step_mem_probe(rollout_id: int, step_id: int, start_allocated: int) -> None: + max_allocated = torch.cuda.max_memory_allocated() + logger.info( + "SLIME_MEM_PROBE train_one_step rollout_id=%s step_id=%s " + "allocated_start=%s allocated_peak=%s allocated_peak_delta=%s", + rollout_id, + step_id, + start_allocated, + max_allocated, + max_allocated - start_allocated, + ) + + def _disable_tqdm_for_non_main_rank() -> bool: return not ( mpu.get_data_parallel_rank(with_context_parallel=True) == 0 @@ -455,6 +472,11 @@ def train_one_step( and gradient norm for logging. """ args = get_args() + mem_probe = _mem_probe_enabled() + mem_probe_start_allocated = 0 + if mem_probe: + torch.cuda.reset_peak_memory_stats() + mem_probe_start_allocated = torch.cuda.memory_allocated() # Set grad to zero. for model_chunk in model: @@ -601,7 +623,11 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p cp_size=mpu.get_context_parallel_world_size(), dp_with_cp_group=mpu.get_data_parallel_group(with_context_parallel=True), ) + if mem_probe: + _log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated) return loss_reduced, grad_norm + if mem_probe: + _log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated) return {}, grad_norm diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 2a858e7a3f..37bf064bf3 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -158,44 +158,95 @@ def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: return -fused_vocab_parallel_cross_entropy(logits, tokens, process_group) -# from https://github.com/volcengine/verl/blob/0bdf7f469854815177e73dcfe9e420836c952e6e/verl/utils/megatron/tensor_parallel.py#L99 -class _VocabParallelEntropy(torch.autograd.Function): +class _VocabParallelLogProbsAndEntropy(torch.autograd.Function): @staticmethod - def forward(ctx, vocab_parallel_logits: torch.Tensor, process_group: dist.ProcessGroup) -> torch.Tensor: + def forward(ctx, vocab_parallel_logits: torch.Tensor, target: torch.Tensor, process_group): + from megatron.core.tensor_parallel.utils import VocabUtility - @torch.compile(dynamic=True) - def mul_reduce(a, b): - return (a * b).sum(dim=-1, keepdim=True) + # Pass None (not a zero-filled tensor) for an output whose grad does not flow, + # so the single-output backward paths skip a wasted full-vocab allocation. + ctx.set_materialize_grads(False) - logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values + logits_max = vocab_parallel_logits.max(dim=-1).values dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) - normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max - normalized_exp_logits = normalized_vocab_parallel_logits.exp_() - normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True) - dist.all_reduce(normalized_sum_exp_logits, group=process_group) - softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits) - sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits) - dist.all_reduce(sum_softmax_times_logits, group=process_group) - entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits - ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits) - return entropy.squeeze(dim=-1) + + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size(-1) + vocab_start_index, vocab_end_index = get_vocab_range( + partition_vocab_size, process_group.rank(), process_group.size() + ) + + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size(0), device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + torch.exp(vocab_parallel_logits, out=vocab_parallel_logits) + sum_exp_logits = vocab_parallel_logits.sum(dim=-1) + dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=process_group) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + log_sum_exp = sum_exp_logits.log() + log_prob = predicted_logits - log_sum_exp + softmax = vocab_parallel_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + local_entropy = torch.zeros_like(log_sum_exp) + for softmax_chunk in softmax.chunk(max(1, softmax.size(-1) // 8192), dim=-1): + local_entropy -= torch.xlogy(softmax_chunk, softmax_chunk).sum(dim=-1) + dist.all_reduce(local_entropy, op=dist.ReduceOp.SUM, group=process_group) + + # `softmax` is reused in place as the gradient buffer in backward (no double-backward). + ctx.save_for_backward(softmax, target_mask, masked_target_1d, local_entropy) + return log_prob, local_entropy.squeeze(dim=-1) @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors - # reuse softmax_logits as grad - vocab_parallel_logits.sub_(sum_softmax_times_logits) - softmax_logits.mul_(vocab_parallel_logits) - softmax_logits.mul_(grad_output.unsqueeze(dim=-1)) - # recover vocab_parallel_logits - vocab_parallel_logits.add_(sum_softmax_times_logits) - softmax_logits.mul_(-1) - return softmax_logits, None + def backward(ctx, grad_log_prob: torch.Tensor, grad_entropy: torch.Tensor): + # Local grad wrt logit z_j (no cross-rank reduce needed): + # g_j = softmax_j * [ -grad_log_prob - grad_entropy * (entropy + log_softmax_j) ] + # + grad_log_prob * 1{j == target} + # The saved softmax buffer is reused in place as the gradient, so the only extra + # full-vocab allocation is `log_softmax`, and only when entropy gradient flows. + softmax, target_mask, masked_target_1d, entropy = ctx.saved_tensors + partition_vocab_size = softmax.size(-1) + + if grad_entropy is not None: + # log_softmax = log(softmax); softmax underflow (==0) -> log(0)=-inf, mapped to 0 + # so the softmax==0 positions contribute nothing (0 * finite). + log_softmax = softmax.log() + log_softmax.nan_to_num_(neginf=0.0) + log_softmax.add_(entropy.unsqueeze(dim=-1)) # entropy + log_softmax + log_softmax.mul_(grad_entropy.unsqueeze(dim=-1).unsqueeze(dim=-1)) + if grad_log_prob is not None: + log_softmax.add_(grad_log_prob.unsqueeze(dim=-1)) + softmax.mul_(log_softmax).neg_() # softmax * [-grad_log_prob - grad_entropy*(H+log p)] + del log_softmax + elif grad_log_prob is not None: + softmax.mul_(grad_log_prob.unsqueeze(dim=-1).neg()) # -softmax * grad_log_prob + else: + return None, None, None + + if grad_log_prob is not None: + grad_2d = softmax.view(-1, partition_vocab_size) + arange_1d = torch.arange(start=0, end=grad_2d.size(0), device=grad_2d.device) + softmax_update = 1.0 - target_mask.view(-1).to(softmax.dtype) + grad_2d[arange_1d, masked_target_1d] += grad_log_prob.reshape(-1) * softmax_update + return softmax.to(torch.bfloat16), None, None -def compute_entropy_from_logits(logits: torch.Tensor, process_group) -> torch.Tensor: - return _VocabParallelEntropy.apply(logits, process_group) + +def compute_log_probs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor, process_group): + logits = logits.unsqueeze(1) + tokens = tokens.unsqueeze(1) + return _VocabParallelLogProbsAndEntropy.apply(logits, tokens, process_group) def get_grpo_returns( @@ -646,6 +697,18 @@ def chunked_gae( return advantages, returns +def _clone_if_grad_tracked(logits: torch.Tensor) -> torch.Tensor: + # Megatron-LM's fused CE mutates its float32 input in place (subtract-max, + # exp(out=...), div_). That is safe to hand over directly + # only when autograd will not observe the mutation. When grad is tracked, an + # in-place write on a (view of a) grad-requiring tensor corrupts the graph and + # raises (issue #1951). Clone exactly when grad is tracked; otherwise pass the + # tensor through so the no-grad ref/old-logprob path keeps its peak-memory win. + if torch.is_grad_enabled() and logits.requires_grad: + return logits.clone() + return logits + + def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1): logits = logits.contiguous() entropy = None @@ -657,22 +720,29 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool if with_entropy: entropys = [] - for logits_chunk in logits_chunks: - entropy_input = logits_chunk.clone() - entropys.append(compute_entropy_from_logits(entropy_input, tp_group)) + log_probs = [] + for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + # The fused helper computes log-probs and entropy from one tensor, + # replacing the two separate destructive passes (and two clones). + log_prob, entropy_chunk = compute_log_probs_and_entropy( + _clone_if_grad_tracked(logits_chunk), tokens_chunk, tp_group + ) + log_probs.append(log_prob) + entropys.append(entropy_chunk) entropy = torch.cat(entropys, dim=0) - - log_probs = [] - for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): - log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group) - log_probs.append(log_prob) + else: + log_probs = [] + for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + log_prob = compute_log_probs(_clone_if_grad_tracked(logits_chunk), tokens_chunk, tp_group) + log_probs.append(log_prob) log_prob = torch.cat(log_probs, dim=0) else: if with_entropy: - entropy_input = logits.clone() - entropy = compute_entropy_from_logits(entropy_input, tp_group) - - log_prob = compute_log_probs(logits.clone(), tokens, tp_group) + # The fused helper computes log-probs and entropy from one tensor, + # replacing the two separate destructive passes (and two clones). + log_prob, entropy = compute_log_probs_and_entropy(_clone_if_grad_tracked(logits), tokens, tp_group) + else: + log_prob = compute_log_probs(_clone_if_grad_tracked(logits), tokens, tp_group) else: log_prob = logits.new_zeros((0,)) if with_entropy: diff --git a/tests/test_logprob_entropy_fused.py b/tests/test_logprob_entropy_fused.py new file mode 100644 index 0000000000..0e579ecd06 --- /dev/null +++ b/tests/test_logprob_entropy_fused.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import sys +import types +from contextlib import contextmanager +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from _cp_dist_helpers import free_port + + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +NUM_GPUS = 0 + + +class _FakeSingleRankGroup: + def rank(self) -> int: + return 0 + + def size(self) -> int: + return 1 + + +def _install_megatron_stubs() -> None: + megatron = sys.modules.setdefault("megatron", types.ModuleType("megatron")) + core = sys.modules.setdefault("megatron.core", types.ModuleType("megatron.core")) + fusions = sys.modules.setdefault("megatron.core.fusions", types.ModuleType("megatron.core.fusions")) + fused = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + tensor_parallel = sys.modules.setdefault( + "megatron.core.tensor_parallel", types.ModuleType("megatron.core.tensor_parallel") + ) + utils = types.ModuleType("megatron.core.tensor_parallel.utils") + + class VocabUtility: + @staticmethod + def vocab_range_from_per_partition_vocab_size(partition_vocab_size: int, rank: int, world_size: int): + assert world_size > 0 + assert 0 <= rank < world_size + start = rank * partition_vocab_size + return start, start + partition_vocab_size + + class _MockVocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits: torch.Tensor, target: torch.Tensor, process_group): + del process_group + logits_max = logits.max(dim=-1, keepdim=True).values + logits.sub_(logits_max) + predicted_logits = logits.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) + torch.exp(logits, out=logits) + sum_exp_logits = logits.sum(dim=-1) + logits.div_(sum_exp_logits.unsqueeze(-1)) + ctx.save_for_backward(logits, target) + return sum_exp_logits.log() - predicted_logits + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + softmax, target = ctx.saved_tensors + grad_input = softmax.clone() + grad_input.scatter_add_( + dim=-1, + index=target.unsqueeze(-1), + src=-torch.ones_like(target, dtype=grad_input.dtype).unsqueeze(-1), + ) + grad_input.mul_(grad_output.unsqueeze(-1)) + return grad_input.to(torch.bfloat16), None, None + + def fused_vocab_parallel_cross_entropy(logits: torch.Tensor, target: torch.Tensor, process_group): + return _MockVocabParallelCrossEntropy.apply(logits, target, process_group) + + fused.fused_vocab_parallel_cross_entropy = fused_vocab_parallel_cross_entropy + utils.VocabUtility = VocabUtility + fusions.fused_cross_entropy = fused + tensor_parallel.utils = utils + core.fusions = fusions + core.tensor_parallel = tensor_parallel + megatron.core = core + sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused + sys.modules["megatron.core.tensor_parallel.utils"] = utils + + +@contextmanager +def _single_rank_all_reduce(): + original_all_reduce = dist.all_reduce + + def all_reduce(tensor, op=None, group=None, async_op=False): + del tensor, op, group + if async_op: + raise NotImplementedError("async all_reduce is not needed by this test") + return None + + dist.all_reduce = all_reduce + try: + yield + finally: + dist.all_reduce = original_all_reduce + + +def _naive_log_probs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor): + log_softmax = torch.log_softmax(logits.float(), dim=-1) + log_probs = log_softmax.gather(dim=-1, index=tokens.unsqueeze(-1)) + entropy = -(log_softmax.exp() * log_softmax).sum(dim=-1) + return log_probs, entropy + + +def _make_inputs(requires_grad: bool = False): + torch.manual_seed(1234) + logits = torch.randn(9, 17, dtype=torch.float32) + logits.requires_grad_(requires_grad) + tokens = torch.randint(0, logits.size(-1), (logits.size(0),), dtype=torch.long) + return logits, tokens + + +def test_fused_forward_matches_naive_reference(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs() + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=True + ) + + ref_log_probs, ref_entropy = _naive_log_probs_and_entropy(logits, tokens) + torch.testing.assert_close(log_probs, ref_log_probs, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(entropy, ref_entropy, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("with_entropy", [False, True]) +def test_chunked_matches_unchunked(with_entropy: bool): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs() + with _single_rank_all_reduce(): + full_log_probs, full_entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=with_entropy, chunk_size=-1 + ) + chunk_log_probs, chunk_entropy = calculate_log_probs_and_entropy( + logits.clone(), tokens, _FakeSingleRankGroup(), with_entropy=with_entropy, chunk_size=4 + ) + + torch.testing.assert_close(chunk_log_probs, full_log_probs, atol=1e-5, rtol=1e-5) + if with_entropy: + torch.testing.assert_close(chunk_entropy, full_entropy, atol=1e-5, rtol=1e-5) + else: + assert chunk_entropy is None + + +def test_no_entropy_chunked_backward_preserves_input_grad(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits, tokens, _FakeSingleRankGroup(), with_entropy=False, chunk_size=4 + ) + assert entropy is None + log_probs.float().sum().backward() + + assert logits.grad is not None + assert torch.isfinite(logits.grad).all() + assert logits.grad.abs().sum() > 0 + + +def test_fused_backward_matches_naive_reference_with_bf16_tolerance(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy(logits, tokens, _FakeSingleRankGroup(), with_entropy=True) + (log_probs.float().sum() + 0.13 * entropy.float().sum()).backward() + + ref_logits = logits.detach().clone().requires_grad_(True) + ref_log_probs, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + (ref_log_probs.float().sum() + 0.13 * ref_entropy.float().sum()).backward() + + torch.testing.assert_close(logits.grad, ref_logits.grad, atol=4e-3, rtol=4e-3) + + +def test_fused_entropy_only_backward_matches_naive_reference(): + _install_megatron_stubs() + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + logits, tokens = _make_inputs(requires_grad=True) + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy(logits, tokens, _FakeSingleRankGroup(), with_entropy=True) + # Only entropy contributes to the loss, so autograd passes + # grad_log_prob=None into the fused backward, exercising the + # entropy-only branch. + entropy.float().sum().backward() + + ref_logits = logits.detach().clone().requires_grad_(True) + _, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + ref_entropy.float().sum().backward() + + assert logits.grad is not None + assert torch.isfinite(logits.grad).all() + torch.testing.assert_close(logits.grad, ref_logits.grad, atol=4e-3, rtol=4e-3) + + +def _tp2_worker(rank: int, master_port: int, result_path: str) -> None: + import os + + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + dist.init_process_group(backend="gloo", rank=rank, world_size=2) + try: + _install_megatron_stubs() + from slime.utils.ppo_utils import compute_log_probs_and_entropy + + torch.manual_seed(2024) + batch_size = 7 + partition_vocab_size = 5 + full_vocab_size = partition_vocab_size * 2 + full_logits = torch.randn(batch_size, full_vocab_size, dtype=torch.float32) + tokens = torch.randint(0, full_vocab_size, (batch_size,), dtype=torch.long) + start = rank * partition_vocab_size + local_logits = full_logits[:, start : start + partition_vocab_size].contiguous().requires_grad_(True) + + log_probs, entropy = compute_log_probs_and_entropy(local_logits, tokens, dist.group.WORLD) + (log_probs.float().sum() + 0.13 * entropy.float().sum()).backward() + + gathered_grads = [torch.empty_like(local_logits.grad) for _ in range(2)] + dist.all_gather(gathered_grads, local_logits.grad) + + if rank == 0: + ref_logits = full_logits.detach().clone().requires_grad_(True) + ref_log_probs, ref_entropy = _naive_log_probs_and_entropy(ref_logits, tokens) + (ref_log_probs.float().sum() + 0.13 * ref_entropy.float().sum()).backward() + + torch.testing.assert_close(log_probs, ref_log_probs, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(entropy, ref_entropy, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(torch.cat(gathered_grads, dim=-1), ref_logits.grad, atol=4e-3, rtol=4e-3) + with open(result_path, "w") as f: + f.write("ok") + finally: + dist.destroy_process_group() + + +def test_tp2_fused_backward_matches_full_vocab_reference(tmp_path): + result_path = str(tmp_path / "tp2_result.txt") + mp.spawn(_tp2_worker, args=(free_port(), result_path), nprocs=2, join=True) + assert (tmp_path / "tp2_result.txt").read_text() == "ok" + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tools/repro_1951.py b/tools/repro_1951.py new file mode 100644 index 0000000000..2f2e7b6a5c --- /dev/null +++ b/tools/repro_1951.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +"""Synthetic reproducer for issue #1951 log-prob/entropy memory peaks.""" + +from __future__ import annotations + +import argparse +import sys +import types +from contextlib import contextmanager + +import torch +import torch.distributed as dist + + +class _FakeSingleRankGroup: + def rank(self) -> int: + return 0 + + def size(self) -> int: + return 1 + + +def _install_mock_fused_cross_entropy() -> None: + """Install a single-rank Megatron fused CE stand-in for local repro runs.""" + megatron = sys.modules.setdefault("megatron", types.ModuleType("megatron")) + core = sys.modules.setdefault("megatron.core", types.ModuleType("megatron.core")) + fusions = sys.modules.setdefault("megatron.core.fusions", types.ModuleType("megatron.core.fusions")) + fused = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + tensor_parallel = sys.modules.setdefault( + "megatron.core.tensor_parallel", types.ModuleType("megatron.core.tensor_parallel") + ) + utils = types.ModuleType("megatron.core.tensor_parallel.utils") + + class VocabUtility: + @staticmethod + def vocab_range_from_per_partition_vocab_size(partition_vocab_size: int, rank: int, world_size: int): + assert world_size == 1 + assert rank == 0 + return 0, partition_vocab_size + + class _MockVocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits: torch.Tensor, target: torch.Tensor, process_group): + del process_group + logits = logits.float() + logits_max = logits.max(dim=-1, keepdim=True).values + logits.sub_(logits_max) + predicted_logits = logits.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) + torch.exp(logits, out=logits) + sum_exp_logits = logits.sum(dim=-1) + logits.div_(sum_exp_logits.unsqueeze(-1)) + ctx.save_for_backward(logits, target) + return sum_exp_logits.log() - predicted_logits + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + softmax, target = ctx.saved_tensors + grad_input = softmax.clone() + grad_input.scatter_add_( + dim=-1, + index=target.unsqueeze(-1), + src=-torch.ones_like(target, dtype=grad_input.dtype).unsqueeze(-1), + ) + grad_input.mul_(grad_output.unsqueeze(-1)) + return grad_input.to(torch.bfloat16), None, None + + def fused_vocab_parallel_cross_entropy(logits: torch.Tensor, target: torch.Tensor, process_group): + return _MockVocabParallelCrossEntropy.apply(logits, target, process_group) + + fused.fused_vocab_parallel_cross_entropy = fused_vocab_parallel_cross_entropy + utils.VocabUtility = VocabUtility + fusions.fused_cross_entropy = fused + tensor_parallel.utils = utils + core.fusions = fusions + core.tensor_parallel = tensor_parallel + megatron.core = core + sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused + sys.modules["megatron.core.tensor_parallel.utils"] = utils + + +@contextmanager +def _single_rank_all_reduce(): + original_all_reduce = dist.all_reduce + + def all_reduce(tensor, op=None, group=None, async_op=False): + del tensor, op, group + if async_op: + raise NotImplementedError("async all_reduce is not needed by this repro") + return None + + dist.all_reduce = all_reduce + try: + yield + finally: + dist.all_reduce = original_all_reduce + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--batch", "-B", type=int, default=4096, help="Number of token positions.") + parser.add_argument("--vocab", "-V", type=int, default=151936, help="Vocabulary dimension.") + parser.add_argument("--chunk-size", type=int, default=-1, help="Forwarded to calculate_log_probs_and_entropy.") + parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="float32") + parser.add_argument("--with-entropy", action="store_true", help="Also compute entropy.") + parser.add_argument("--backward", action="store_true", help="Run backward through the returned tensors.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--use-real-megatron", action="store_true", help="Do not install the local fused CE mock.") + return parser.parse_args() + + +def _fmt_bytes(value: int) -> str: + return f"{value / 1024**3:.3f} GiB ({value} bytes)" + + +def main() -> None: + args = _parse_args() + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required for peak-memory measurement") + + if not args.use_real_megatron: + _install_mock_fused_cross_entropy() + + from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + dtype = getattr(torch, args.dtype) + device = torch.device("cuda") + torch.manual_seed(args.seed) + torch.cuda.reset_peak_memory_stats() + + logits = torch.randn(args.batch, args.vocab, device=device, dtype=dtype) + if args.backward: + logits.requires_grad_(True) + tokens = torch.randint(args.vocab, (args.batch,), device=device) + torch.cuda.synchronize() + + allocated_after_logits = torch.cuda.memory_allocated() + torch.cuda.reset_peak_memory_stats() + + with _single_rank_all_reduce(): + log_probs, entropy = calculate_log_probs_and_entropy( + logits, + tokens, + _FakeSingleRankGroup(), + with_entropy=args.with_entropy, + chunk_size=args.chunk_size, + ) + if args.backward: + loss = log_probs.float().sum() + if entropy is not None: + loss = loss + entropy.float().sum() + loss.backward() + + torch.cuda.synchronize() + peak = torch.cuda.max_memory_allocated() + current = torch.cuda.memory_allocated() + + print(f"shape=({args.batch}, {args.vocab}) dtype={args.dtype} with_entropy={args.with_entropy}") + print(f"chunk_size={args.chunk_size} backward={args.backward} mock_megatron={not args.use_real_megatron}") + print(f"allocated_after_logits={_fmt_bytes(allocated_after_logits)}") + print(f"peak_during_call={_fmt_bytes(peak)}") + print(f"peak_delta_after_logits={_fmt_bytes(peak - allocated_after_logits)}") + print(f"allocated_after_call={_fmt_bytes(current)}") + + +if __name__ == "__main__": + main()