Skip to content

perf(megatron-loss): scale logits per-chunk to avoid OOM#2010

Open
Yangruipis wants to merge 2 commits into
THUDM:mainfrom
redai-infra:fix/wuhuan/div_oom
Open

perf(megatron-loss): scale logits per-chunk to avoid OOM#2010
Yangruipis wants to merge 2 commits into
THUDM:mainfrom
redai-infra:fix/wuhuan/div_oom

Conversation

@Yangruipis

Copy link
Copy Markdown
Contributor
# ⚡ Performance

## Move rollout_temperature division into per-chunk yield in get_responses

- Remove full-tensor `logits.div(rollout_temperature)` that allocated a
  duplicate `[T, V]` fp32 buffer (~16 GiB on Qwen3 with long packed
  sequences), doubling loss-step peak memory and triggering OOM under
  allocator fragmentation
- Apply the scalar division to each `logits_chunk` right before yielding,
  so allocations are bounded by per-sample response size and happen
  incrementally instead of as a single giant contiguous block
- Numerically equivalent across all four chunking paths (cp_size==1 RL,
  SFT, allgather_cp, zigzag CP) since scalar division commutes with
  slicing and concatenation

- Remove full-tensor `logits.div(rollout_temperature)` that allocated a
	duplicate `[T, V]` fp32 buffer (~16 GiB on Qwen3 with long packed
	sequences), doubling loss-step peak memory and triggering OOM under
	allocator fragmentation
- Apply the scalar division to each `logits_chunk` right before yielding,
	so allocations are bounded by per-sample response size and happen
	incrementally instead of as a single giant contiguous block
- Numerically equivalent across all four chunking paths (cp_size==1 RL,
	SFT, allgather_cp, zigzag CP) since scalar division commutes with
	slicing and concatenation
@Yangruipis Yangruipis closed this Jun 2, 2026
@Yangruipis Yangruipis reopened this Jun 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant