diff --git a/deep_ep/hybrid_ep_buffer.py b/deep_ep/hybrid_ep_buffer.py index b3b075bb..f51c86da 100644 --- a/deep_ep/hybrid_ep_buffer.py +++ b/deep_ep/hybrid_ep_buffer.py @@ -25,30 +25,22 @@ def indices_to_map( """ # Generate the routing map and the probs according to the topk_idx and topk_weights. assert topk_idx is not None + topk_idx = topk_idx.to(torch.int64) routing_map = torch.zeros( - num_of_tokens, num_of_experts, dtype=torch.bool - ).cuda() - # routing_map = routing_map.scatter(1, topk_idx.to(torch.int64), 1).bool() - batch_size = routing_map.shape[0] - num_experts = routing_map.shape[1] - topk = topk_idx.shape[1] - row_indices = paddle.arange(0, batch_size, dtype=topk_idx.dtype).unsqueeze(1).expand([batch_size, topk]) - indices = paddle.stack([row_indices, topk_idx], axis=2).reshape([-1, 2]) - - tmp = paddle.zeros([batch_size, num_experts], dtype='float32') - ones = paddle.ones([indices.shape[0],], dtype='float32') - tmp = paddle.scatter_nd_add(tmp, indices, ones) - - routing_map = (tmp > 0).astype('bool') + num_of_tokens, num_of_experts, device="cuda", dtype=torch.uint8 + ) + routing_map = paddle.put_along_axis( + routing_map, + topk_idx, + torch.ones(topk_idx.shape, device="cuda", dtype=torch.uint8), + axis=1, + ).bool() if topk_weights is not None: probs = torch.zeros( - num_of_tokens, num_of_experts, dtype=torch.float32 - ).cuda() - updates = topk_weights.reshape([-1]) - tmp = paddle.zeros_like(probs) - tmp = paddle.scatter_nd_add(tmp, indices, updates) - probs = tmp + num_of_tokens, num_of_experts, device="cuda", dtype=torch.float32 + ) + probs = paddle.put_along_axis(probs, topk_idx, topk_weights, axis=1) else: probs = None return routing_map, probs