From 7df9698471a2b04cd8c7e8135814a115a5852621 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Tue, 19 May 2026 14:43:21 +0800 Subject: [PATCH] Use put_along_axis for Paddle routing metadata Replace the temporary scatter_nd_add construction in Paddle HybridEP indices_to_map with put_along_axis, matching the upstream scatter semantics while keeping uint8 routing storage because Paddle does not provide a CUDA bool put_along_axis kernel. Validation:\n- 2x8 A1B topk=2 DeepEP vs HybridEP 50-step bitwise check: final_layernorm_output MD5 matched 100/100 for ranks 0 and 8; tr_loss_before_reduce matched 50/50 for ranks 0 and 8. Co-authored-by: Codex --- deep_ep/hybrid_ep_buffer.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/deep_ep/hybrid_ep_buffer.py b/deep_ep/hybrid_ep_buffer.py index b3b075bb..f51c86da 100644 --- a/deep_ep/hybrid_ep_buffer.py +++ b/deep_ep/hybrid_ep_buffer.py @@ -25,30 +25,22 @@ def indices_to_map( """ # Generate the routing map and the probs according to the topk_idx and topk_weights. assert topk_idx is not None + topk_idx = topk_idx.to(torch.int64) routing_map = torch.zeros( - num_of_tokens, num_of_experts, dtype=torch.bool - ).cuda() - # routing_map = routing_map.scatter(1, topk_idx.to(torch.int64), 1).bool() - batch_size = routing_map.shape[0] - num_experts = routing_map.shape[1] - topk = topk_idx.shape[1] - row_indices = paddle.arange(0, batch_size, dtype=topk_idx.dtype).unsqueeze(1).expand([batch_size, topk]) - indices = paddle.stack([row_indices, topk_idx], axis=2).reshape([-1, 2]) - - tmp = paddle.zeros([batch_size, num_experts], dtype='float32') - ones = paddle.ones([indices.shape[0],], dtype='float32') - tmp = paddle.scatter_nd_add(tmp, indices, ones) - - routing_map = (tmp > 0).astype('bool') + num_of_tokens, num_of_experts, device="cuda", dtype=torch.uint8 + ) + routing_map = paddle.put_along_axis( + routing_map, + topk_idx, + torch.ones(topk_idx.shape, device="cuda", dtype=torch.uint8), + axis=1, + ).bool() if topk_weights is not None: probs = torch.zeros( - num_of_tokens, num_of_experts, dtype=torch.float32 - ).cuda() - updates = topk_weights.reshape([-1]) - tmp = paddle.zeros_like(probs) - tmp = paddle.scatter_nd_add(tmp, indices, updates) - probs = tmp + num_of_tokens, num_of_experts, device="cuda", dtype=torch.float32 + ) + probs = paddle.put_along_axis(probs, topk_idx, topk_weights, axis=1) else: probs = None return routing_map, probs