Use put_along_axis for Paddle routing metadata #16
Merged
Replace the temporary scatter_nd_add construction in Paddle HybridEP indices_to_map with put_along_axis, matching the upstream scatter semantics while keeping uint8 routing storage because Paddle does not provide a CUDA bool put_along_axis kernel.

Validation:
- 2x8 A1B topk=2 DeepEP vs HybridEP 50-step bitwise check: final_layernorm_output MD5 matched 100/100 for ranks 0 and 8; tr_loss_before_reduce matched 50/50 for ranks 0 and 8.

Co-authored-by: Codex <noreply@openai.com>
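A bitwise check of this kind can be sketched as follows. This is only an illustration under assumptions: the helper names `tensor_md5` and `count_mismatches` are hypothetical, NumPy arrays stand in for the dumped tensors, and the actual comparison script used for this PR is not shown in the source.

```python
import hashlib

import numpy as np


def tensor_md5(array):
    """Hash the raw bytes of an array, so any single-bit difference changes the digest."""
    return hashlib.md5(np.ascontiguousarray(array).tobytes()).hexdigest()


def count_mismatches(run_a, run_b):
    """Count step-aligned dumps whose digests differ between two runs (hypothetical helper)."""
    return sum(tensor_md5(a) != tensor_md5(b) for a, b in zip(run_a, run_b))


# Toy stand-ins for per-step dumps of one tensor from two runs.
baseline = [np.arange(4, dtype=np.float32) * step for step in range(3)]
identical = [arr.copy() for arr in baseline]
perturbed = [arr.copy() for arr in baseline]
perturbed[1] = perturbed[1] + np.float32(1e-3)  # a tiny numeric drift in one step

same_run_mismatches = count_mismatches(baseline, identical)
drift_mismatches = count_mismatches(baseline, perturbed)
```

Hashing raw bytes rather than comparing with a tolerance is what makes the check bitwise: a result such as 100/100 matching means no step differed at all.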
Pull request overview
This PR updates the Paddle HybridEP indices_to_map() helper to build dense routing metadata using paddle.put_along_axis instead of the previous scatter_nd_add-based construction, aiming to reduce index expansion and intermediate tensor overhead under Paddle compat.
Changes:
- Replaced the `scatter_nd_add`-based dense routing map/prob construction with `paddle.put_along_axis`.
- Reused a single `topk_idx` int64 conversion to avoid repeated casts.
- Kept routing map materialization via `uint8 -> bool` to work around missing CUDA bool `put_along_axis` support.
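The construction described in these changes can be sketched with NumPy's `np.put_along_axis` as a stand-in for `paddle.put_along_axis`. This is not the PR's code: the function name `indices_to_map_sketch`, its parameters, and the tensor shapes are assumptions for illustration, and it assumes every index in `topk_idx` is a valid expert id in `[0, num_experts)` with no duplicates per token.

```python
import numpy as np


def indices_to_map_sketch(topk_idx, topk_probs, num_experts):
    """Hypothetical NumPy stand-in for building a dense routing map and probs."""
    num_tokens = topk_idx.shape[0]
    idx = topk_idx.astype(np.int64)  # convert once, reuse for both writes

    # Write 1s into a uint8 buffer, then cast to bool, mirroring the
    # uint8 -> bool workaround (NumPy itself does not need it).
    routing_u8 = np.zeros((num_tokens, num_experts), dtype=np.uint8)
    np.put_along_axis(routing_u8, idx, np.uint8(1), axis=1)
    routing_map = routing_u8.astype(bool)

    # Scatter each token's top-k probabilities into its expert columns.
    probs = np.zeros((num_tokens, num_experts), dtype=topk_probs.dtype)
    np.put_along_axis(probs, idx, topk_probs, axis=1)
    return routing_map, probs


# Two tokens, four experts, topk=2.
example_idx = np.array([[0, 2], [1, 3]], dtype=np.int32)
example_probs = np.array([[0.6, 0.4], [0.7, 0.3]], dtype=np.float32)
example_map, example_dense = indices_to_map_sketch(example_idx, example_probs, 4)
```

Each row of `example_map` marks the experts a token is routed to, and `example_dense` holds the matching probabilities in the same columns, with zeros elsewhere.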
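Background
The Paddle HybridEP `indices_to_map()` previously built the dense routing map/probs with `scatter_nd_add`. That worked around a compatibility problem with `torch.scatter` under Paddle compat, but the implementation was more roundabout than the upstream scatter semantics and introduced extra index expansion and temporary tensors.

Changes
- Switched the `scatter_nd_add` path in `indices_to_map()` to `paddle.put_along_axis`.
- Reused a single `topk_idx.to(torch.int64)` conversion to avoid repeated casts.
- Kept the `uint8 -> bool` routing map construction, because Paddle currently has no CUDA bool `put_along_axis` kernel, so the upstream `dtype=torch.bool` scatter form cannot be restored directly.
- `device="cuda"`, which is closer to the upstream form.

Validation
Bitwise alignment:
- `final_layernorm_output` MD5: `ordered_unique_neq = 0/100` on both rank 0 and rank 8.
- `tr_loss_before_reduce`: `paired_neq = 0/50` on both rank 0 and rank 8.

Performance validation:
- `put_along_axis` versus the old `scatter_nd_add` implementation, measured over steps 51-100.
- `global_steps_per_second`: `0.359827` vs `0.349351`, about `+2.91%`.
- `tokens_per_sec_per_card`: `5895.404` vs `5723.763`, about `+2.91%`.
- `dispatch/combine` time is close overall; end-to-end throughput improves slightly.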
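The reported percentages can be re-derived from the raw numbers. The sketch below assumes the first value in each pair is the `put_along_axis` run and the second is the old `scatter_nd_add` run; the helper `relative_gain` is hypothetical. Under that assumption, `+2.91%` corresponds to the difference divided by the new value, while dividing by the old baseline gives roughly `+3.00%`.

```python
def relative_gain(new, old, base):
    """Percentage difference between new and old, expressed against the given base."""
    return (new - old) / base * 100.0


# Values copied from the performance validation; the new/old ordering is an assumption.
steps_new, steps_old = 0.359827, 0.349351
tokens_new, tokens_old = 5895.404, 5723.763

steps_vs_new = round(relative_gain(steps_new, steps_old, steps_new), 2)
steps_vs_old = round(relative_gain(steps_new, steps_old, steps_old), 2)
tokens_vs_new = round(relative_gain(tokens_new, tokens_old, tokens_new), 2)
tokens_vs_old = round(relative_gain(tokens_new, tokens_old, tokens_old), 2)

print(steps_vs_new, steps_vs_old, tokens_vs_new, tokens_vs_old)
```

Either way the two metrics agree with each other, which is expected if tokens per step are fixed across both runs.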