From 7c1dc429e0441bf9f952ff8233102509bb887568 Mon Sep 17 00:00:00 2001 From: Sam White Date: Thu, 31 Oct 2024 14:11:38 -0500 Subject: [PATCH] Add support for IMPI Thread-Split mode using OpenMP --- csrc/cpu/comm/ccl.cpp | 1 + csrc/cpu/comm/coll_mpi.cpp | 86 +++++++++++++++++++++++++++++++++----- csrc/cpu/comm/coll_mpi.hpp | 1 + 3 files changed, 77 insertions(+), 11 deletions(-) diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp index 722c3fd51b13..101f555dbbfe 100644 --- a/csrc/cpu/comm/ccl.cpp +++ b/csrc/cpu/comm/ccl.cpp @@ -278,6 +278,7 @@ void inference_all_reduce(torch::Tensor& data, py::object op) switch (data.scalar_type()) { case c10::ScalarType::BFloat16: data_size = numel * 2; break; + case c10::ScalarType::Half: data_size = numel * 2; break; case c10::ScalarType::Float: data_size = numel * 4; break; default: data_type_fallback = true; } diff --git a/csrc/cpu/comm/coll_mpi.cpp b/csrc/cpu/comm/coll_mpi.cpp index 3122be013e99..8546ee4f4bea 100644 --- a/csrc/cpu/comm/coll_mpi.cpp +++ b/csrc/cpu/comm/coll_mpi.cpp @@ -1,18 +1,54 @@ +// 1 = handoff, 0 = thread-split +#define HANDOFF 0 + #include #include "coll_mpi.hpp" +#if !HANDOFF // THREAD_SPLIT +#include +std::vector thread_comm; +static bool thread_comm_inited = false; +#endif + void init_mpi(void) { - MPI_Init(NULL, NULL); + int mpi_inited; + MPI_Initialized(&mpi_inited); + if (!mpi_inited) { +#if HANDOFF + MPI_Init(NULL, NULL); +#else // THREAD_SPLIT + int provided; + MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &provided); +#endif + } + init_mpi_thread_comms(); +} - //int size, rank; - //MPI_Comm_size(MPI_COMM_WORLD, &size); - //MPI_Comm_rank(MPI_COMM_WORLD, &rank); +void init_mpi_thread_comms(void) +{ +#if !HANDOFF + if (!thread_comm_inited) { + MPI_Info info; + char s[16]; + int num_threads = omp_get_max_threads(); + thread_comm.resize(num_threads); + for (int i = 0; i < num_threads; i++) { + MPI_Comm_dup(MPI_COMM_WORLD, &thread_comm[i]); + snprintf(s, 16, "%d", i); + MPI_Info_create(&info); + MPI_Info_set(info, "thread_id", s); + MPI_Comm_set_info(thread_comm[i], info); + MPI_Info_free(&info); + } + thread_comm_inited = true; + } +#endif } +/* char temp_buf[64*1024*1024]; -/* void naive_all_reduce(int world_size, int rank, void* buf, size_t data_size, size_t numel, c10::ScalarType scalar_type) { if (rank == 0) { @@ -155,19 +191,47 @@ void ring_all_reduce(int world_size, int rank, void* buf, size_t data_size, size void mpi_all_reduce(int world_size, int rank, void* buf, size_t data_size, size_t numel, c10::ScalarType scalar_type) { +#if HANDOFF switch (scalar_type) { case c10::ScalarType::BFloat16: - //naive_all_reduce(world_size, rank, buf, data_size, numel, scalar_type); - //ring_all_reduce(world_size, rank, buf, data_size, numel, scalar_type); - //rabenseifner_all_reduce(world_size, rank, buf, data_size, numel, scalar_type); + MPI_Allreduce(MPI_IN_PLACE, buf, numel, MPIX_C_BF16, MPI_SUM, MPI_COMM_WORLD); + break; + case c10::ScalarType::Half: + MPI_Allreduce(MPI_IN_PLACE, buf, numel, MPIX_C_FLOAT16, MPI_SUM, MPI_COMM_WORLD); break; case c10::ScalarType::Float: MPI_Allreduce(MPI_IN_PLACE, buf, numel, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD); - //naive_all_reduce(world_size, rank, buf, data_size, numel, scalar_type); - //ring_all_reduce(world_size, rank, buf, data_size, numel, scalar_type); - //rabenseifner_all_reduce(world_size, rank, buf, data_size, numel, scalar_type); break; default: assert(!"Should not get here"); } + +#else // THREAD_SPLIT + + // Could tune number of threads for performance based on numel... + int nthds = std::min((size_t)omp_get_max_threads(), numel); + + #pragma omp parallel for num_threads(nthds) schedule(static) shared(nthds, numel) + for (int tid = 0; tid < nthds; tid++) + { + size_t my_numel = numel / nthds; + char *my_buf = (char *)buf + (tid * my_numel); + if (tid == nthds - 1) { // Last thread may have uneven number of elements + my_numel = numel - (my_numel * (nthds - 1)); // Could balance better... + } + + switch (scalar_type) { + case c10::ScalarType::BFloat16: + MPI_Allreduce(MPI_IN_PLACE, my_buf, my_numel, MPIX_C_BF16, MPI_SUM, thread_comm[tid]); + break; + case c10::ScalarType::Half: + MPI_Allreduce(MPI_IN_PLACE, my_buf, my_numel, MPIX_C_FLOAT16, MPI_SUM, thread_comm[tid]); + break; + case c10::ScalarType::Float: + MPI_Allreduce(MPI_IN_PLACE, my_buf, my_numel, MPI_FLOAT, MPI_SUM, thread_comm[tid]); + break; + default: assert(!"Should not get here"); + } + } // omp parallel for +#endif // THREAD_SPLIT } diff --git a/csrc/cpu/comm/coll_mpi.hpp b/csrc/cpu/comm/coll_mpi.hpp index 3ff50f8bbc09..2524cbc3eea2 100644 --- a/csrc/cpu/comm/coll_mpi.hpp +++ b/csrc/cpu/comm/coll_mpi.hpp @@ -4,6 +4,7 @@ #include void init_mpi(void); +void init_mpi_thread_comms(void); void mpi_all_reduce(int world_size, int rank, void* buf, size_t data_size, size_t numel, c10::ScalarType scalar_type); #endif //_COLL_MPI__HPP_