Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
{'test_file': 'test_metric_report.py', 'num_gpus': 0},
{'test_file': 'test_metric_report_dist.py', 'num_gpus': 0},
{'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0},
{'test_file': 'test_logprob_entropy_fused.py', 'num_gpus': 0},
{'test_file': 'test_value_temperature.py', 'num_gpus': 0},
{'test_file': 'test_rm_f1.py', 'num_gpus': 0},
{'test_file': 'test_rm_gpqa.py', 'num_gpus': 0},
Expand Down
4 changes: 3 additions & 1 deletion examples/retool/retool_qwen3_4b_rl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ PERF_ARGS=(
# --micro-batch-size 1
--use-dynamic-batch-size
--max-tokens-per-gpu 9216
# Bound the fused cross-entropy [tokens, vocab] transient that can OOM on long retool traces.
--log-probs-chunk-size 1024
)

GRPO_ARGS=(
Expand Down Expand Up @@ -153,4 +155,4 @@ ray job submit --address="http://127.0.0.1:8265" \
${EVAL_ARGS[@]} \
${SGLANG_ARGS[@]} \
${MISC_ARGS[@]} \
${CUSTOM_ARGS[@]}
${CUSTOM_ARGS[@]}
26 changes: 26 additions & 0 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@
logger = logging.getLogger(__name__)


def _mem_probe_enabled() -> bool:
return os.environ.get("SLIME_MEM_PROBE", "0") == "1" and torch.cuda.is_available()


def _log_train_step_mem_probe(rollout_id: int, step_id: int, start_allocated: int) -> None:
    """Log the CUDA allocation peak observed for one training step.

    Args:
        rollout_id: Rollout identifier, echoed into the log line.
        step_id: Step identifier, echoed into the log line.
        start_allocated: Allocated-memory reading taken by the caller at the
            start of the step; used as the baseline for the reported delta.
    """
    # Peak as reported by torch since the peak stats were last reset.
    peak_allocated = torch.cuda.max_memory_allocated()
    peak_delta = peak_allocated - start_allocated
    logger.info(
        "SLIME_MEM_PROBE train_one_step rollout_id=%s step_id=%s "
        "allocated_start=%s allocated_peak=%s allocated_peak_delta=%s",
        rollout_id,
        step_id,
        start_allocated,
        peak_allocated,
        peak_delta,
    )


def _disable_tqdm_for_non_main_rank() -> bool:
return not (
mpu.get_data_parallel_rank(with_context_parallel=True) == 0
Expand Down Expand Up @@ -455,6 +472,11 @@ def train_one_step(
and gradient norm for logging.
"""
args = get_args()
mem_probe = _mem_probe_enabled()
mem_probe_start_allocated = 0
if mem_probe:
torch.cuda.reset_peak_memory_stats()
mem_probe_start_allocated = torch.cuda.memory_allocated()

# Set grad to zero.
for model_chunk in model:
Expand Down Expand Up @@ -601,7 +623,11 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
cp_size=mpu.get_context_parallel_world_size(),
dp_with_cp_group=mpu.get_data_parallel_group(with_context_parallel=True),
)
if mem_probe:
_log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated)
return loss_reduced, grad_norm
if mem_probe:
_log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated)
return {}, grad_norm


Expand Down
152 changes: 111 additions & 41 deletions slime/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,44 +158,95 @@ def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group:
return -fused_vocab_parallel_cross_entropy(logits, tokens, process_group)


# from https://github.com/volcengine/verl/blob/0bdf7f469854815177e73dcfe9e420836c952e6e/verl/utils/megatron/tensor_parallel.py#L99
class _VocabParallelEntropy(torch.autograd.Function):
class _VocabParallelLogProbsAndEntropy(torch.autograd.Function):

@staticmethod
def forward(ctx, vocab_parallel_logits: torch.Tensor, process_group: dist.ProcessGroup) -> torch.Tensor:
def forward(ctx, vocab_parallel_logits: torch.Tensor, target: torch.Tensor, process_group):
from megatron.core.tensor_parallel.utils import VocabUtility

@torch.compile(dynamic=True)
def mul_reduce(a, b):
return (a * b).sum(dim=-1, keepdim=True)
# Pass None (not a zero-filled tensor) for an output whose grad does not flow,
# so the single-output backward paths skip a wasted full-vocab allocation.
ctx.set_materialize_grads(False)

logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
logits_max = vocab_parallel_logits.max(dim=-1).values
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
normalized_exp_logits = normalized_vocab_parallel_logits.exp_()
normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
dist.all_reduce(normalized_sum_exp_logits, group=process_group)
softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)
sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)
dist.all_reduce(sum_softmax_times_logits, group=process_group)
entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
return entropy.squeeze(dim=-1)

get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size(-1)
vocab_start_index, vocab_end_index = get_vocab_range(
partition_vocab_size, process_group.rank(), process_group.size()
)

vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0

logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size(0), device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0

torch.exp(vocab_parallel_logits, out=vocab_parallel_logits)
sum_exp_logits = vocab_parallel_logits.sum(dim=-1)
dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=process_group)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

log_sum_exp = sum_exp_logits.log()
log_prob = predicted_logits - log_sum_exp
softmax = vocab_parallel_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

local_entropy = torch.zeros_like(log_sum_exp)
for softmax_chunk in softmax.chunk(max(1, softmax.size(-1) // 8192), dim=-1):
local_entropy -= torch.xlogy(softmax_chunk, softmax_chunk).sum(dim=-1)
dist.all_reduce(local_entropy, op=dist.ReduceOp.SUM, group=process_group)

# `softmax` is reused in place as the gradient buffer in backward (no double-backward).
ctx.save_for_backward(softmax, target_mask, masked_target_1d, local_entropy)
return log_prob, local_entropy.squeeze(dim=-1)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
# reuse softmax_logits as grad
vocab_parallel_logits.sub_(sum_softmax_times_logits)
softmax_logits.mul_(vocab_parallel_logits)
softmax_logits.mul_(grad_output.unsqueeze(dim=-1))
# recover vocab_parallel_logits
vocab_parallel_logits.add_(sum_softmax_times_logits)
softmax_logits.mul_(-1)
return softmax_logits, None
def backward(ctx, grad_log_prob: torch.Tensor, grad_entropy: torch.Tensor):
# Local grad wrt logit z_j (no cross-rank reduce needed):
# g_j = softmax_j * [ -grad_log_prob - grad_entropy * (entropy + log_softmax_j) ]
# + grad_log_prob * 1{j == target}
# The saved softmax buffer is reused in place as the gradient, so the only extra
# full-vocab allocation is `log_softmax`, and only when entropy gradient flows.
softmax, target_mask, masked_target_1d, entropy = ctx.saved_tensors
partition_vocab_size = softmax.size(-1)

if grad_entropy is not None:
# log_softmax = log(softmax); softmax underflow (==0) -> log(0)=-inf, mapped to 0
# so the softmax==0 positions contribute nothing (0 * finite).
log_softmax = softmax.log()
log_softmax.nan_to_num_(neginf=0.0)
log_softmax.add_(entropy.unsqueeze(dim=-1)) # entropy + log_softmax
log_softmax.mul_(grad_entropy.unsqueeze(dim=-1).unsqueeze(dim=-1))
if grad_log_prob is not None:
log_softmax.add_(grad_log_prob.unsqueeze(dim=-1))
softmax.mul_(log_softmax).neg_() # softmax * [-grad_log_prob - grad_entropy*(H+log p)]
del log_softmax
elif grad_log_prob is not None:
softmax.mul_(grad_log_prob.unsqueeze(dim=-1).neg()) # -softmax * grad_log_prob
else:
return None, None, None

if grad_log_prob is not None:
grad_2d = softmax.view(-1, partition_vocab_size)
arange_1d = torch.arange(start=0, end=grad_2d.size(0), device=grad_2d.device)
softmax_update = 1.0 - target_mask.view(-1).to(softmax.dtype)
grad_2d[arange_1d, masked_target_1d] += grad_log_prob.reshape(-1) * softmax_update

return softmax.to(torch.bfloat16), None, None

def compute_entropy_from_logits(logits: torch.Tensor, process_group) -> torch.Tensor:
return _VocabParallelEntropy.apply(logits, process_group)

def compute_log_probs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor, process_group):
    """Run the fused log-prob/entropy function on ``logits`` and ``tokens``.

    A size-1 axis is inserted at dim 1 of both inputs before they are handed to
    ``_VocabParallelLogProbsAndEntropy``; its ``(log_prob, entropy)`` result is
    returned unchanged.
    """
    expanded_logits = logits.unsqueeze(1)
    expanded_tokens = tokens.unsqueeze(1)
    return _VocabParallelLogProbsAndEntropy.apply(expanded_logits, expanded_tokens, process_group)


def get_grpo_returns(
Expand Down Expand Up @@ -646,6 +697,18 @@ def chunked_gae(
return advantages, returns


def _clone_if_grad_tracked(logits: torch.Tensor) -> torch.Tensor:
# Megatron-LM's fused CE mutates its float32 input in place (subtract-max,
# exp(out=...), div_). That is safe to hand over directly
# only when autograd will not observe the mutation. When grad is tracked, an
# in-place write on a (view of a) grad-requiring tensor corrupts the graph and
# raises (issue #1951). Clone exactly when grad is tracked; otherwise pass the
# tensor through so the no-grad ref/old-logprob path keeps its peak-memory win.
if torch.is_grad_enabled() and logits.requires_grad:
return logits.clone()
return logits


def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1):
logits = logits.contiguous()
entropy = None
Expand All @@ -657,22 +720,29 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool

if with_entropy:
entropys = []
for logits_chunk in logits_chunks:
entropy_input = logits_chunk.clone()
entropys.append(compute_entropy_from_logits(entropy_input, tp_group))
log_probs = []
for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
# The fused helper computes log-probs and entropy from one tensor,
# replacing the two separate destructive passes (and two clones).
log_prob, entropy_chunk = compute_log_probs_and_entropy(
_clone_if_grad_tracked(logits_chunk), tokens_chunk, tp_group
)
log_probs.append(log_prob)
entropys.append(entropy_chunk)
entropy = torch.cat(entropys, dim=0)

log_probs = []
for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group)
log_probs.append(log_prob)
else:
log_probs = []
for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
log_prob = compute_log_probs(_clone_if_grad_tracked(logits_chunk), tokens_chunk, tp_group)
log_probs.append(log_prob)
log_prob = torch.cat(log_probs, dim=0)
else:
if with_entropy:
entropy_input = logits.clone()
entropy = compute_entropy_from_logits(entropy_input, tp_group)

log_prob = compute_log_probs(logits.clone(), tokens, tp_group)
# The fused helper computes log-probs and entropy from one tensor,
# replacing the two separate destructive passes (and two clones).
log_prob, entropy = compute_log_probs_and_entropy(_clone_if_grad_tracked(logits), tokens, tp_group)
else:
log_prob = compute_log_probs(_clone_if_grad_tracked(logits), tokens, tp_group)
else:
log_prob = logits.new_zeros((0,))
if with_entropy:
Expand Down
Loading
Loading