From 41e0b957613c550d559ec4aa24ed5d0899452b90 Mon Sep 17 00:00:00 2001 From: GordonYang1 <146812179@qq.com> Date: Mon, 18 May 2026 07:42:01 +0000 Subject: [PATCH] feat: support send and receive --- examples/send_recv.cc | 432 ++++++++++++++++++++++++++++++++++++++++++ include/comm.h | 10 +- src/base/recv.h | 61 ++++++ src/base/send.h | 62 ++++++ src/ompi/impl/recv.h | 72 +++++++ src/ompi/impl/send.h | 74 ++++++++ 6 files changed, 710 insertions(+), 1 deletion(-) create mode 100644 examples/send_recv.cc create mode 100644 src/base/recv.h create mode 100644 src/base/send.h create mode 100644 src/ompi/impl/recv.h create mode 100644 src/ompi/impl/send.h diff --git a/examples/send_recv.cc b/examples/send_recv.cc new file mode 100644 index 0000000..f686614 --- /dev/null +++ b/examples/send_recv.cc @@ -0,0 +1,432 @@ +/** + * InfiniCCL Example/Test: Point-to-Point Send/Recv. + * + * Runs a small suite of cases covering blocking P2P: + * 1. count=0 blocking ping (rank 0 -> 1) + * 2. blocking ping, rank 0 -> 1 + * 3. blocking ping, rank 0 -> size-1 + * 4. blocking ping-pong, rank 0 <-> 1 + * 5. large count (>INT_MAX bytes), gated by INFINI_SENDRECV_LARGE=1 + * 6. invalid peer (-1 and size) -> infiniInvalidArgument + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "backend_manifest.h" +#include "device.h" +#include "infiniccl.h" +#include "runtime.h" +#include "traits.h" +#include "utils.h" + +using namespace infini::ccl; + +namespace { + +struct CaseResult { + bool ok = true; + bool skipped = false; + std::string note; +}; + +bool AllRanksOk(bool local_ok) { + int local = local_ok ? 
1 : 0; + int global = 0; + MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, MPI_COMM_WORLD); + return global != 0; +} + +void PrintCase(int rank, const std::string &name, const CaseResult &local, + bool global_ok) { + if (rank != 0) { + return; + } + const char *GREEN = "\033[32m"; + const char *YELLOW = "\033[33m"; + const char *RED = "\033[31m"; + const char *RESET = "\033[0m"; + + std::string status; + if (local.skipped) { + status = std::string(YELLOW) + "SKIP" + RESET; + } else if (global_ok) { + status = std::string(GREEN) + "PASS" + RESET; + } else { + status = std::string(RED) + "FAIL" + RESET; + } + + std::cout << "[" << name << "] " << status; + if (!local.note.empty()) { + std::cout << " (rank0: " << local.note << ")"; + } + std::cout << std::endl; +} + +CaseResult SkipNeed2Ranks(int rank) { + return {true, true, (rank == 0) ? "needs at least 2 ranks" : ""}; +} + +int GetEnvInt(const char *name, int fallback) { + const char *value = std::getenv(name); + if (value == nullptr || value[0] == '\0') { + return fallback; + } + + char *end = nullptr; + long parsed = std::strtol(value, &end, 10); + if (end == value || *end != '\0') { + return fallback; + } + + return static_cast(parsed); +} + +int GetLocalRank() { + int local_rank = GetEnvInt("OMPI_COMM_WORLD_LOCAL_RANK", -1); + if (local_rank >= 0) { + return local_rank; + } + + local_rank = GetEnvInt("MPI_LOCALRANKID", -1); + if (local_rank >= 0) { + return local_rank; + } + + local_rank = GetEnvInt("SLURM_LOCALID", -1); + if (local_rank >= 0) { + return local_rank; + } + + return 0; +} + +template +void PrintRankMapping(int rank, int size) { + for (int r = 0; r < size; ++r) { + MPI_Barrier(MPI_COMM_WORLD); + + if (rank == r) { + char host[256] = {}; + if (gethostname(host, sizeof(host)) != 0) { + std::snprintf(host, sizeof(host), "unknown"); + } + + std::cout << "[Rank " << rank << "] Host: " << host + << " | GPU: " << Device::StringFromType(kDev) << " | Device " + << GetLocalRank() << std::endl; + } + } 
+ + MPI_Barrier(MPI_COMM_WORLD); +} + +// Allocate a device buffer holding `count` floats == `value`. +template +float *AllocFilled(size_t count, float value) { + float *d = nullptr; + Runtime::Malloc(&d, count * sizeof(float)); + std::vector h(count, value); + Runtime::Memcpy(d, h.data(), count * sizeof(float), + Runtime::MemcpyHostToDevice); + return d; +} + +// Verify `count` floats on device equal `expected`. +template +bool DeviceEqualsFloat(const float *d, size_t count, float expected) { + std::vector h(count); + Runtime::Memcpy(h.data(), d, count * sizeof(float), + Runtime::MemcpyDeviceToHost); + for (size_t i = 0; i < count; ++i) { + if (std::fabs(h[i] - expected) > 1e-3) { + return false; + } + } + return true; +} + +// --------------------------------------------------------------------------- +// Case 1: count=0 blocking ping +// --------------------------------------------------------------------------- +CaseResult Case_Count0Ping(infiniComm_t comm, int rank, int size) { + if (size < 2) { + return SkipNeed2Ranks(rank); + } + if (rank == 0) { + infiniResult_t s = infiniSend(nullptr, 0, infiniFloat32, 1, comm, nullptr); + if (s != infiniSuccess) { + return {false, false, + "send returned " + std::to_string(static_cast(s))}; + } + } else if (rank == 1) { + infiniResult_t s = infiniRecv(nullptr, 0, infiniFloat32, 0, comm, nullptr); + if (s != infiniSuccess) { + return {false, false, + "recv returned " + std::to_string(static_cast(s))}; + } + } + return {}; +} + +// --------------------------------------------------------------------------- +// Helper: run a basic blocking ping `sender → receiver` for `count` floats +// of value `value`, and have the receiver verify on the device side. 
+// --------------------------------------------------------------------------- +template +CaseResult RunBlockingPing(infiniComm_t comm, int rank, int sender, + int receiver, size_t count, float value) { + if (rank == sender) { + float *d_send = AllocFilled(count, value); + infiniResult_t s = + infiniSend(d_send, count, infiniFloat32, receiver, comm, nullptr); + Runtime::Free(d_send); + if (s != infiniSuccess) { + return {false, false, + "send returned " + std::to_string(static_cast(s))}; + } + } else if (rank == receiver) { + float *d_recv = AllocFilled(count, -1.0f); + infiniResult_t s = + infiniRecv(d_recv, count, infiniFloat32, sender, comm, nullptr); + if (s != infiniSuccess) { + Runtime::Free(d_recv); + return {false, false, + "recv returned " + std::to_string(static_cast(s))}; + } + bool ok = DeviceEqualsFloat(d_recv, count, value); + Runtime::Free(d_recv); + if (!ok) { + return {false, false, "received data did not match expected value"}; + } + } + return {}; +} + +// --------------------------------------------------------------------------- +// Case 2/3: blocking ping rank 0 → 1, rank 0 → size-1 +// --------------------------------------------------------------------------- +template +CaseResult Case_BlockingPing01(infiniComm_t comm, int rank, int size) { + if (size < 2) return SkipNeed2Ranks(rank); + return RunBlockingPing(comm, rank, /*sender=*/0, /*receiver=*/1, + /*count=*/1024, /*value=*/3.5f); +} + +template +CaseResult Case_BlockingPing0Last(infiniComm_t comm, int rank, int size) { + if (size < 2) return SkipNeed2Ranks(rank); + return RunBlockingPing(comm, rank, /*sender=*/0, + /*receiver=*/size - 1, + /*count=*/2048, /*value=*/-7.25f); +} + +// --------------------------------------------------------------------------- +// Case 4: blocking ping-pong, rank 0 ↔ 1 +// --------------------------------------------------------------------------- +template +CaseResult Case_BlockingPingPong01(infiniComm_t comm, int rank, int size) { + if (size < 2) return 
SkipNeed2Ranks(rank); + constexpr size_t kCount = 512; + constexpr float kForward = 11.0f; + constexpr float kReply = -22.0f; + + if (rank == 0) { + float *d_out = AllocFilled(kCount, kForward); + float *d_in = AllocFilled(kCount, -1.0f); + + infiniResult_t s = + infiniSend(d_out, kCount, infiniFloat32, 1, comm, nullptr); + if (s != infiniSuccess) { + Runtime::Free(d_out); + Runtime::Free(d_in); + return {false, false, + "rank0 send returned " + std::to_string(static_cast(s))}; + } + s = infiniRecv(d_in, kCount, infiniFloat32, 1, comm, nullptr); + if (s != infiniSuccess) { + Runtime::Free(d_out); + Runtime::Free(d_in); + return {false, false, + "rank0 recv returned " + std::to_string(static_cast(s))}; + } + bool ok = DeviceEqualsFloat(d_in, kCount, kReply); + Runtime::Free(d_out); + Runtime::Free(d_in); + if (!ok) return {false, false, "rank0 reply mismatch"}; + } else if (rank == 1) { + float *d_in = AllocFilled(kCount, -1.0f); + infiniResult_t s = + infiniRecv(d_in, kCount, infiniFloat32, 0, comm, nullptr); + if (s != infiniSuccess) { + Runtime::Free(d_in); + return {false, false, + "rank1 recv returned " + std::to_string(static_cast(s))}; + } + bool ok = DeviceEqualsFloat(d_in, kCount, kForward); + Runtime::Free(d_in); + if (!ok) return {false, false, "rank1 forward mismatch"}; + + float *d_out = AllocFilled(kCount, kReply); + s = infiniSend(d_out, kCount, infiniFloat32, 0, comm, nullptr); + Runtime::Free(d_out); + if (s != infiniSuccess) { + return {false, false, + "rank1 send returned " + std::to_string(static_cast(s))}; + } + } + return {}; +} + +// --------------------------------------------------------------------------- +// Case 5: large count chunking (blocking, gated by INFINI_SENDRECV_LARGE) +// --------------------------------------------------------------------------- +template +CaseResult Case_LargeCount(infiniComm_t comm, int rank, int size) { + if (size < 2) return SkipNeed2Ranks(rank); + if (std::getenv("INFINI_SENDRECV_LARGE") == nullptr) { + return { + true, true, 
(rank == 0) ? "set INFINI_SENDRECV_LARGE=1 to enable (~2GB/rank)" : ""}; + } + const size_t count = static_cast(std::numeric_limits::max()) + + static_cast(1024); + const std::int8_t expected = 0x5A; + const size_t total_bytes = count * sizeof(std::int8_t); + + if (rank == 0) { + std::int8_t *d_out = nullptr; + Runtime::Malloc(&d_out, total_bytes); + std::vector h_out(count, expected); + Runtime::Memcpy(d_out, h_out.data(), total_bytes, + Runtime::MemcpyHostToDevice); + infiniResult_t s = infiniSend(d_out, count, infiniChar, 1, comm, nullptr); + Runtime::Free(d_out); + if (s != infiniSuccess) { + return {false, false, + "Send returned " + std::to_string(static_cast(s))}; + } + } else if (rank == 1) { + std::int8_t *d_in = nullptr; + Runtime::Malloc(&d_in, total_bytes); + infiniResult_t s = infiniRecv(d_in, count, infiniChar, 0, comm, nullptr); + if (s != infiniSuccess) { + Runtime::Free(d_in); + return {false, false, + "Recv returned " + std::to_string(static_cast(s))}; + } + std::int8_t probes[3] = {-1, -1, -1}; + Runtime::Memcpy(&probes[0], d_in, sizeof(std::int8_t), + Runtime::MemcpyDeviceToHost); + Runtime::Memcpy(&probes[1], d_in + (count / 2), sizeof(std::int8_t), + Runtime::MemcpyDeviceToHost); + Runtime::Memcpy(&probes[2], d_in + (count - 1), sizeof(std::int8_t), + Runtime::MemcpyDeviceToHost); + Runtime::Free(d_in); + if (probes[0] != expected || probes[1] != expected || + probes[2] != expected) { + return {false, false, "head/mid/tail mismatch"}; + } + } + return {}; +} + +// --------------------------------------------------------------------------- +// Case 6: invalid peer (-1 and size) -> infiniInvalidArgument +// --------------------------------------------------------------------------- +CaseResult Case_InvalidPeer(infiniComm_t comm, int rank, int size) { + if (size < 2) return SkipNeed2Ranks(rank); + // Only rank 0 attempts the bad calls; the impl rejects them at the base + // validator before any MPI traffic, so other ranks don't need to mirror. 
+ if (rank != 0) { + return {}; + } + float dummy = 0.f; + for (int bad_peer : {-1, size}) { + infiniResult_t s = + infiniSend(&dummy, 1, infiniFloat32, bad_peer, comm, nullptr); + if (s != infiniInvalidArgument) { + return {false, false, + "Send peer=" + std::to_string(bad_peer) + " expected " + + std::to_string(static_cast(infiniInvalidArgument)) + + ", got " + std::to_string(static_cast(s))}; + } + s = infiniRecv(&dummy, 1, infiniFloat32, bad_peer, comm, nullptr); + if (s != infiniInvalidArgument) { + return {false, false, + "Recv peer=" + std::to_string(bad_peer) + " expected " + + std::to_string(static_cast(infiniInvalidArgument)) + + ", got " + std::to_string(static_cast(s))}; + } + } + return {}; +} + +} // namespace + +int main(int argc, char **argv) { + constexpr Device::Type kDevType = + ListGetBest(EnabledDevices{}); + + CHECK_INFINI(infiniInit(&argc, &argv)); + + int rank = 0; + int size = 0; + CHECK_INFINI(infiniGetRank(&rank)); + CHECK_INFINI(infiniGetSize(&size)); + + if (rank == 0) { + std::cout << "=== Send/Recv Test Suite ===" << std::endl; + std::cout << "Device: " << Device::StringFromType(kDevType) << std::endl; + std::cout << "Ranks: " << size << std::endl; + } + + PrintRankMapping(rank, size); + + infiniComm_t comm = nullptr; + CHECK_INFINI(infiniCommInitAll(&comm, size, nullptr)); + + bool overall_ok = true; + + auto run = [&](const std::string &name, CaseResult local) { + bool global_ok = AllRanksOk(local.ok); + PrintCase(rank, name, local, global_ok); + if (!local.skipped) { + overall_ok = overall_ok && global_ok; + } + }; + + run("count=0 ping (blocking)", Case_Count0Ping(comm, rank, size)); + run("blocking ping, 0 -> 1", Case_BlockingPing01(comm, rank, size)); + run("blocking ping, 0 -> size-1", + Case_BlockingPing0Last(comm, rank, size)); + run("blocking ping-pong, 0 <-> 1", + Case_BlockingPingPong01(comm, rank, size)); + run("large count (>INT_MAX bytes)", + Case_LargeCount(comm, rank, size)); + run("invalid peer", 
Case_InvalidPeer(comm, rank, size)); + + if (rank == 0) { + const char *GREEN = "\033[32m"; + const char *RED = "\033[31m"; + const char *RESET = "\033[0m"; + std::cout << "\n=== Summary ===" << std::endl; + std::cout << (overall_ok ? (std::string(GREEN) + "ALL PASS" + RESET) + : (std::string(RED) + "FAILED" + RESET)) + << std::endl; + } + + CHECK_INFINI(infiniCommDestroy(comm)); + CHECK_INFINI(infiniFinalize()); + + return overall_ok ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/include/comm.h b/include/comm.h index 0029345..9e4c96c 100644 --- a/include/comm.h +++ b/include/comm.h @@ -45,8 +45,16 @@ infiniResult_t infiniAllGather(const void *sendbuff, void *recvbuff, size_t count, infiniDataType_t datatype, infiniComm_t comm, void *stream); +infiniResult_t infiniSend(const void *sendbuff, size_t count, + infiniDataType_t datatype, int peer, + infiniComm_t comm, void *stream); + +infiniResult_t infiniRecv(void *recvbuff, size_t count, + infiniDataType_t datatype, int peer, + infiniComm_t comm, void *stream); + #ifdef __cplusplus } #endif -#endif // INFINI_CCL_COMM_H_ +#endif // INFINI_CCL_COMM_H_ diff --git a/src/base/recv.h b/src/base/recv.h new file mode 100644 index 0000000..d46570a --- /dev/null +++ b/src/base/recv.h @@ -0,0 +1,61 @@ +#ifndef INFINI_CCL_BASE_RECV_H_ +#define INFINI_CCL_BASE_RECV_H_ + +#include "communicator.h" +#include "data_type_impl.h" +#include "logging.h" +#include "operation.h" +#include "return_status_impl.h" + +namespace infini::ccl { + +template +struct RecvImpl; + +class Recv : public Operation { + public: + template + static ReturnStatus Execute(void *recv_buff, size_t count, DataType datatype, + int peer, void *comm_handle, void *stream) { + if (!comm_handle) { + LOG("Invalid communicator handle for Recv."); + return ReturnStatus::kInvalidArgument; + } + + auto *comm = static_cast(comm_handle); + if (HasInvalidArgs(recv_buff, count, datatype, peer, comm)) { + return ReturnStatus::kInvalidArgument; + } + if (count == 0) { + return 
ReturnStatus::kSuccess; + } + + return RecvImpl::Apply( + recv_buff, count, datatype, peer, comm, stream); + } + + private: + static bool HasInvalidArgs(const void *recv_buff, size_t count, + DataType datatype, int peer, Communicator *comm) { + if (datatype < DataType::kChar || datatype >= DataType::kNumTypes) { + LOG("Invalid data type for Recv."); + return true; + } + if (peer < 0 || peer >= comm->size()) { + LOG("Invalid peer rank for Recv."); + return true; + } + if (count == 0) { + return false; + } + if (!recv_buff) { + LOG("Invalid receive buffer pointer for Recv."); + return true; + } + return false; + } +}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_BASE_RECV_H_ diff --git a/src/base/send.h b/src/base/send.h new file mode 100644 index 0000000..d911419 --- /dev/null +++ b/src/base/send.h @@ -0,0 +1,62 @@ +#ifndef INFINI_CCL_BASE_SEND_H_ +#define INFINI_CCL_BASE_SEND_H_ + +#include "communicator.h" +#include "data_type_impl.h" +#include "logging.h" +#include "operation.h" +#include "return_status_impl.h" + +namespace infini::ccl { + +template +struct SendImpl; + +class Send : public Operation { + public: + template + static ReturnStatus Execute(const void *send_buff, size_t count, + DataType datatype, int peer, void *comm_handle, + void *stream) { + if (!comm_handle) { + LOG("Invalid communicator handle for Send."); + return ReturnStatus::kInvalidArgument; + } + + auto *comm = static_cast(comm_handle); + if (HasInvalidArgs(send_buff, count, datatype, peer, comm)) { + return ReturnStatus::kInvalidArgument; + } + if (count == 0) { + return ReturnStatus::kSuccess; + } + + return SendImpl::Apply( + send_buff, count, datatype, peer, comm, stream); + } + + private: + static bool HasInvalidArgs(const void *send_buff, size_t count, + DataType datatype, int peer, Communicator *comm) { + if (datatype < DataType::kChar || datatype >= DataType::kNumTypes) { + LOG("Invalid data type for Send."); + return true; + } + if (peer < 0 || peer >= comm->size()) { + 
LOG("Invalid peer rank for Send."); + return true; + } + if (count == 0) { + return false; + } + if (!send_buff) { + LOG("Invalid send buffer pointer for Send."); + return true; + } + return false; + } +}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_BASE_SEND_H_ diff --git a/src/ompi/impl/recv.h b/src/ompi/impl/recv.h new file mode 100644 index 0000000..1ed02a4 --- /dev/null +++ b/src/ompi/impl/recv.h @@ -0,0 +1,72 @@ +#ifndef INFINI_CCL_OMPI_IMPL_RECV_H_ +#define INFINI_CCL_OMPI_IMPL_RECV_H_ + +#include +#include + +#include "base/recv.h" +#include "communicator.h" +#include "data_type_impl.h" +#include "logging.h" +#include "ompi/checks.h" +#include "ompi/comm_instance.h" + +namespace infini::ccl { + +template +class RecvImpl { + public: + static ReturnStatus Apply(void *recv_buff, size_t count, DataType data_type, + int peer, Communicator *comm, void *stream) { + constexpr Device::Type kDev = + ListGetBest(ActiveDevices{}); + + auto *inst = static_cast(comm->inter_comm()); + if (!inst || inst->handle == MPI_COMM_NULL) { + LOG("Invalid OpenMPI communicator instance for Recv."); + return ReturnStatus::kInternalError; + } + + size_t type_size = kDataTypeToSize.at(data_type); + if (count > std::numeric_limits::max() / type_size) { + LOG("Recv byte size overflow."); + return ReturnStatus::kInvalidArgument; + } + + size_t total_bytes = count * type_size; + void *host_buf = std::malloc(total_bytes); + if (!host_buf) { + LOG("Failed to allocate host buffer for Recv staging."); + return ReturnStatus::kSystemError; + } + + auto *bytes = static_cast(host_buf); + size_t offset = 0; + constexpr size_t kMaxMpiCount = + static_cast(std::numeric_limits::max()); + constexpr int kTag = 0; + while (offset < total_bytes) { + size_t chunk = total_bytes - offset; + if (chunk > kMaxMpiCount) { + chunk = kMaxMpiCount; + } + INFINI_CHECK_MPI(MPI_Recv(bytes + offset, static_cast(chunk), + MPI_BYTE, peer, kTag, inst->handle, + MPI_STATUS_IGNORE)); + offset += chunk; + } + + 
Runtime::Memcpy(recv_buff, host_buf, total_bytes, + Runtime::MemcpyHostToDevice); + + std::free(host_buf); + return ReturnStatus::kSuccess; + } +}; + +template <> +struct BackendEnabled : std::true_type {}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_OMPI_IMPL_RECV_H_ diff --git a/src/ompi/impl/send.h b/src/ompi/impl/send.h new file mode 100644 index 0000000..211b552 --- /dev/null +++ b/src/ompi/impl/send.h @@ -0,0 +1,74 @@ +#ifndef INFINI_CCL_OMPI_IMPL_SEND_H_ +#define INFINI_CCL_OMPI_IMPL_SEND_H_ + +#include +#include + +#include "base/send.h" +#include "communicator.h" +#include "data_type_impl.h" +#include "logging.h" +#include "ompi/checks.h" +#include "ompi/comm_instance.h" + +namespace infini::ccl { + +template +class SendImpl { + public: + static ReturnStatus Apply(const void *send_buff, size_t count, + DataType data_type, int peer, Communicator *comm, + void *stream) { + constexpr Device::Type kDev = + ListGetBest(ActiveDevices{}); + + auto *inst = static_cast(comm->inter_comm()); + if (!inst || inst->handle == MPI_COMM_NULL) { + LOG("Invalid OpenMPI communicator instance for Send."); + return ReturnStatus::kInternalError; + } + + size_t type_size = kDataTypeToSize.at(data_type); + if (count > std::numeric_limits::max() / type_size) { + LOG("Send byte size overflow."); + return ReturnStatus::kInvalidArgument; + } + + size_t total_bytes = count * type_size; + void *host_buf = std::malloc(total_bytes); + if (!host_buf) { + LOG("Failed to allocate host buffer for Send staging."); + return ReturnStatus::kSystemError; + } + + Runtime::Memcpy(host_buf, send_buff, total_bytes, + Runtime::MemcpyDeviceToHost); + Runtime::StreamSynchronize( + static_cast::Stream>(stream)); + + auto *bytes = static_cast(host_buf); + size_t offset = 0; + constexpr size_t kMaxMpiCount = + static_cast(std::numeric_limits::max()); + constexpr int kTag = 0; + while (offset < total_bytes) { + size_t chunk = total_bytes - offset; + if (chunk > kMaxMpiCount) { + chunk = 
kMaxMpiCount; + } + INFINI_CHECK_MPI(MPI_Send(bytes + offset, static_cast(chunk), + MPI_BYTE, peer, kTag, inst->handle)); + offset += chunk; + } + + std::free(host_buf); + return ReturnStatus::kSuccess; + } +}; + +template <> +struct BackendEnabled : std::true_type {}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_OMPI_IMPL_SEND_H_