From 177b2ecf75fc89a2f3ba8a8481a6d9e2d9ee3267 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 2 Dec 2025 19:44:39 +0000 Subject: [PATCH 1/9] cuBlasMp backend logic added to TE/common with connections to framework extensions Signed-off-by: Alp Dener --- build_tools/pytorch.py | 6 + .../comm_gemm_overlap/comm_gemm_overlap.cpp | 177 ++++++++++++++---- .../transformer_engine/comm_gemm_overlap.h | 32 ++++ .../common/util/pybind_helper.h | 4 + transformer_engine/jax/cpp_extensions/gemm.py | 26 ++- .../jax/csrc/extensions/cgemm_helper.cpp | 28 ++- .../jax/csrc/extensions/cgemm_helper.h | 3 +- .../jax/csrc/extensions/gemm.cpp | 15 +- transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 15 ++ .../pytorch/csrc/extensions/pybind.cpp | 6 + transformer_engine/pytorch/module/base.py | 89 ++++++--- 12 files changed, 311 insertions(+), 91 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index b03ef04fa42..82df324cd7d 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -96,6 +96,12 @@ def setup_pytorch_extension( libraries.append("nvshmem_host") cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") + if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))): + # Creating a cuBlasMp context requires direct access to the underlying NCCL + # communicator in a tensor-parallel process group. The header for ProcessGroupNCCL + # needs this CPP directive to be included properly. + cxx_flags.append("-DUSE_C10D_NCCL") + # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 56369db27fa..8dab9492c00 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -23,6 +23,14 @@ using namespace std::placeholders; +bool nvte_built_with_cublasmp() { +#ifdef NVTE_WITH_CUBLASMP + return true; +#else + return false; +#endif +} + namespace transformer_engine { namespace { @@ -33,10 +41,6 @@ std::vector shape_to_vector(const NVTEShape &shape) { } // namespace -/*************************************************************************************************** - * Comm+GEMM Overlap Common Core - **************************************************************************************************/ - bool ubuf_built_with_mpi() { #ifdef NVTE_UB_WITH_MPI return true; @@ -45,6 +49,10 @@ bool ubuf_built_with_mpi() { #endif } +/*************************************************************************************************** + * Comm+GEMM Overlap Common Core + **************************************************************************************************/ + CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, @@ -69,6 +77,30 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl num_comm_sm, set_sm_margin, use_ce, atomic_gemm); } +CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, + int num_comm_sm, bool is_p2p, bool atomic_gemm) { + _with_cublasmp = true; + + nvte_comm_gemm_ctx_create(reinterpret_cast(nccl_comm_ptr), tp_size, tp_rank); + + _num_comm_sm = num_comm_sm; + _is_p2p = is_p2p; + _atomic_gemm = atomic_gemm; + if 
(_is_p2p) { + if (_atomic_gemm) { + _algo_type = kNVTECommGemmAlgoAtomicP2P; + } else { + _algo_type = kNVTECommGemmAlgoSplitP2P; + } + } else { + if (_atomic_gemm) { + _algo_type = kNVTECommGemmAlgoAtomicMulticast; + } else { + _algo_type = kNVTECommGemmAlgoSplitMulticast; + } + } +} + void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, @@ -134,39 +166,43 @@ void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_stream } CommOverlapCore::~CommOverlapCore() { - cudaEventDestroy(_stop_comm); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - if (_comm_launch_event) { - cudaEventDestroy(_comm_launch_event); - } + if (_with_cublasmp) { + nvte_comm_gemm_ctx_destroy(_cublasmp_ctx); + } else { + cudaEventDestroy(_stop_comm); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); + if (_comm_launch_event) { + cudaEventDestroy(_comm_launch_event); + } - if (_atomic_gemm) { - cudaFree(_counter.dptr()); - } + if (_atomic_gemm) { + cudaFree(_counter.dptr()); + } - for (size_t i = 0; i < _stream_compute.size(); i++) { - cudaStreamSynchronize(_stream_compute[i]); - cudaStreamDestroy(_stream_compute[i]); - } + for (size_t i = 0; i < _stream_compute.size(); i++) { + cudaStreamSynchronize(_stream_compute[i]); + cudaStreamDestroy(_stream_compute[i]); + } - auto error = cudaGetLastError(); - if (error != cudaSuccess) { - NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); - } + auto error = cudaGetLastError(); + if (error != cudaSuccess) { + NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); + } - if (_comm_created) { - try { -#ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); -#else - destroy_communicator(_ub_comm); -#endif - } catch (const std::exception &e) { - NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + if (_comm_created) { + try { + #ifdef NVTE_UB_WITH_MPI + destroy_communicator_mpi(_ub_comm); + #else + destroy_communicator(_ub_comm); + #endif + } catch (const std::exception &e) { + NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } + _comm_created = false; } - _comm_created = false; } } @@ -272,6 +308,34 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source return chunk; } +void CommOverlapCore::cublasmp_ag_gemm( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, bool grad, bool accumulate, + cudaStream_t stream_main) { + int64_t m = transa ? A.size(0) : A.size(1); + int64_t n_local = transb ? B.size(1) : B.size(0); + int64_t n = n_local * _tp_size; + int64_t k = transa ? A.size(1) : A.size(0); + + nvte_all_gather_gemm(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm, + stream_main, _algo_type); +} + +void CommOverlapCore::cublasmp_gemm_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, bool grad, bool accumulate, + cudaStream_t stream_main) { + int64_t m = transa ? A.size(0) : A.size(1); + int64_t n = transb ? B.size(1) : B.size(0); + int64_t k_local = transa ? 
A.size(1) : A.size(0);
+  int64_t k = k_local * _tp_size;
+
+  nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
+                          pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
+                          stream_main, _algo_type);
+}
+
 /***************************************************************************************************
  * Comm+GEMM Overlap Base (Pipelined / Collective)
  **************************************************************************************************/
@@ -313,9 +377,11 @@ void CommOverlapBase::initialize(const std::vector &buffer_shape, DType
 }
 
 CommOverlapBase::~CommOverlapBase() {
-  cudaEventDestroy(_start_d2dcopy);
-  cudaStreamSynchronize(_stream_comm);
-  cudaStreamDestroy(_stream_comm);
+  if (!_with_cublasmp) {
+    cudaEventDestroy(_start_d2dcopy);
+    cudaStreamSynchronize(_stream_comm);
+    cudaStreamDestroy(_stream_comm);
+  }
 }
 
 /*
@@ -328,6 +394,8 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
                                    bool accumulate, bool use_split_accumulator,
                                    CommOverlapType comm_type, TensorWrapper &rs_output,
                                    cudaStream_t stream_main) {
+  NVTE_CHECK(!_with_cublasmp, "Bulk overlap is not supported with cuBlasMp");
+
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -385,10 +453,15 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa
                                              TensorWrapper &workspace, bool grad, bool accumulate,
                                              bool use_split_accumulator, TensorWrapper &rs_output,
                                              cudaStream_t stream_main) {
+  if (_with_cublasmp) {
+    return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main);
+  }
+
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
   _ub_comm->cga_size = _cga_size;
+
   // Get GEMM dimensions
   size_t m = transa ? A.size(0) : A.size(1);
   size_t k = transa ? 
A.size(1) : A.size(0); @@ -481,6 +554,10 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + } + // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -724,16 +801,20 @@ void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DTy } CommOverlapP2PBase::~CommOverlapP2PBase() { - cudaEventDestroy(_stop_recv); - cudaEventDestroy(_stop_send); - cudaStreamDestroy(_stream_recv); - for (size_t i = 0; i < _stream_send.size(); i++) { - cudaStreamDestroy(_stream_send[i]); + if (!_with_cublasmp) { + cudaEventDestroy(_stop_recv); + cudaEventDestroy(_stop_send); + cudaStreamDestroy(_stream_recv); + for (size_t i = 0; i < _stream_send.size(); i++) { + cudaStreamDestroy(_stream_send[i]); + } } } void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, bool rowwise) { + if (_with_cublasmp) return; // cuBlasMp executes its own copy-into-buffer op + // Check element size const size_t element_size = source.element_size(); NVTE_CHECK(_ubuf.element_size() == element_size, @@ -788,6 +869,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -890,6 +975,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -1057,6 +1146,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -1121,6 +1214,10 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index cffc411a0d3..73abe45dcc2 100644 --- 
a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -17,6 +18,12 @@ #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +/* \brief Check if TE is built with cuBlasMp. + * + * \return True if TE is built with cuBlasMp. + */ +bool nvte_built_with_cublasmp(); + namespace transformer_engine { /* \brief Check if Userbufers bootstraps with direct calls to MPI collectives. @@ -59,6 +66,10 @@ class CommOverlapCore { bool _atomic_gemm{false}; bool _is_p2p{false}; + bool _with_cublasmp{false}; + NVTECommGemmCtx *_cublasmp_ctx{nullptr}; + NVTECommGemmAlgoType _algo_type = kNVTECommGemmAlgoDefault; + TensorWrapper _ubuf; TensorWrapper _counter; float *_ubuf_scale_inv; @@ -81,6 +92,9 @@ class CommOverlapCore { int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm); + CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm, + bool is_p2p, bool atomic_gemm); + virtual ~CommOverlapCore(); void *get_ubuf_dptr() { return _ubuf.dptr(); } @@ -109,6 +123,16 @@ class CommOverlapCore { bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + bool with_cublasmp() { return _with_cublasmp; } + + void cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + bool grad, bool accumulate, cudaStream_t stream_main); + + void cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + bool grad, bool accumulate, cudaStream_t stream_main); + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, @@ -177,6 +201,10 @@ class CommOverlapBase : public CommOverlapCore { bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); + CommOverlapBase(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm = 16, + bool atomic_gemm = false) + : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, false, atomic_gemm) {} + virtual ~CommOverlapBase(); /* @@ -257,6 +285,10 @@ class CommOverlapP2PBase : public CommOverlapCore { int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + CommOverlapP2PBase(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm = 1, + bool atomic_gemm = false) + : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, true, atomic_gemm) {} + virtual ~CommOverlapP2PBase(); void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index bce124e705f..34e664ccbf0 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -112,6 +112,8 @@ .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ py::call_guard()) \ .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()) \ + .def("with_cublasmp", &transformer_engine::CommOverlapCore::with_cublasmp, \ py::call_guard()); \ py::class_, \ @@ -135,6 +137,8 @@ }, \ py::call_guard(), py::arg("device_id") = -1); \ 
m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); \ + m.def("nvte_built_with_cublasmp", &nvte_built_with_cublasmp, \ py::call_guard()); #endif diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 76a8b225ba8..37c6b01a9d0 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -386,7 +386,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19) inner_primitive = None outer_primitive = None @@ -411,6 +411,7 @@ def abstract( sequence_dim, is_outer, collective_op, + use_cublasmp, ): del use_split_accumulator, transpose_batch_sequence @@ -538,7 +539,11 @@ def _dims_are_consecutive(dims): if scaling_mode.is_nvfp4_scaling: workspace_size += lhs_scale_inv.size + rhs_scale_inv.size if not collective_op.is_none: - workspace_size *= get_cgemm_num_max_streams() + if use_cublasmp: + # cuBlasMp manages its own cuBlasLt workspaces per stream + workspace_size = 0 + else: + workspace_size *= get_cgemm_num_max_streams() # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. workspace_size += 256 @@ -573,6 +578,7 @@ def lowering( sequence_dim, is_outer, collective_op, + use_cublasmp, ): del out_dtype, transpose_batch_sequence, sequence_dim, is_outer @@ -617,6 +623,7 @@ def lowering( "grad": grad, "use_split_accumulator": use_split_accumulator, "collective_op": int(collective_op.value), + "use_cublasmp": use_cublasmp, } operand_output_aliases = {} @@ -651,6 +658,7 @@ def impl( sequence_dim, is_outer, collective_op, + use_cublasmp, ): if scaling_mode.is_1d_block_scaling(): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -718,6 +726,7 @@ def impl( transpose_batch_sequence=transpose_batch_sequence, sequence_dim=sequence_dim, is_outer=is_outer, + use_cublasmp=use_cublasmp, ) # Alter output blocks for CGEMM AG if ( @@ -769,6 +778,7 @@ def outer_impl( sequence_dim, is_outer, collective_op, + use_cublasmp, ): return GemmPrimitive.impl( lhs, @@ -790,6 +800,7 @@ def outer_impl( sequence_dim, is_outer, collective_op, + use_cublasmp, ) @staticmethod @@ -803,10 +814,11 @@ def batcher( fuse_gelu, grad, use_split_accumulator, - collective_op, transpose_batch_sequence, sequence_dim, is_outer, + collective_op, + use_cublasmp, ): del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None @@ -840,6 +852,7 @@ def batcher( transpose_batch_sequence=transpose_batch_sequence, sequence_dim=sequence_dim, is_outer=is_outer, + use_cublasmp=use_cublasmp, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -1002,6 +1015,7 @@ def infer_sharding_from_operands( sequence_dim, is_outer, collective_op, + use_cublasmp, mesh, arg_infos, result_infos, @@ -1013,6 +1027,7 @@ def infer_sharding_from_operands( result_infos, is_outer, sequence_dim, + use_cublasmp, ) (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( @@ -1047,6 +1062,7 @@ def partition( sequence_dim, is_outer, collective_op, + use_cublasmp, mesh, arg_infos, result_infos, @@ -1125,6 +1141,7 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alph sequence_dim=inferred_sequence_dim, is_outer=False, collective_op=collective_op, + use_cublasmp=use_cublasmp, ) if reduce_spec is 
not None: @@ -1156,6 +1173,7 @@ def shardy_sharding_rule( sequence_dim, is_outer, collective_op, + use_cublasmp, mesh, operand_types, result_types, @@ -1250,6 +1268,7 @@ def _te_gemm( use_split_accumulator: bool = None, transpose_batch_sequence: bool = False, collective_op: CollectiveOp = CollectiveOp.NONE, + use_cublasmp: bool = False, ) -> Tuple[jax.Array, ...]: if grad or fuse_gelu: @@ -1353,6 +1372,7 @@ def _te_gemm( sequence_dim=-1, # Dummy value and will be set in the primitive is_outer=True, collective_op=collective_op, + use_cublasmp=use_cublasmp, ) diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index 7082bfb035b..1ee47e46d30 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -62,7 +62,7 @@ ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &i } void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id, - int tp_size) { + int tp_size, bool use_cublasmp) { // Validate inputs NVTE_CHECK(num_devices_per_process == 1, "num_devices_per_process must be == 1, got num_devices_per_process=", @@ -159,7 +159,8 @@ int GetCgemmNumMaxStreams() { CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector buffer_shape, DType dtype, - JAXX_Collective_Op collective_op) { + JAXX_Collective_Op collective_op, + bool use_cublasmp) { auto &comm_handler = CommunicatorHandler::get(); auto &cgemm_config = CgemmConfig::get(); @@ -192,14 +193,21 @@ CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector bu } std::unique_ptr executor; - executor = std::make_unique( - buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, - comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, - comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, - comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), - cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, - cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, - cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + if (use_cublasmp) { + executor = std::make_unique( + reinterpret_cast(comm_handler.get_comm_for_current_device()), + comm_handler.tp_size, comm_handler.get_tp_domain_id(), + cgemm_config.num_comm_sm, cgemm_config.aggregate_ag); + } else { + executor = std::make_unique( + buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, + comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, + comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, + comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), + cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, + cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, + cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + } CommOverlapCore *executor_ptr = executor.get(); plan_map[plan_id] = std::move(executor); diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h index 84b2b815401..cf25d4a0515 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.h +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ 
-165,7 +165,8 @@ class CollectiveGemmPlanRegistry { } CommOverlapCore *get_executor(std::vector buffer_shape, DType dtype, - JAXX_Collective_Op collective_op); + JAXX_Collective_Op collective_op, + bool use_cublasmp = false); private: CollectiveGemmPlanRegistry() {} diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6566ff1689e..9dfb806cf85 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -96,7 +96,8 @@ Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buf JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, - bool use_split_accumulator, JAXX_Collective_Op collective_op) { + bool use_split_accumulator, JAXX_Collective_Op collective_op, + bool use_cublasmp) { nvte_cublas_handle_init(); // Init UB buffer @@ -123,7 +124,7 @@ Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buf buffer_shape[1] = out_shape[1]; } auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype, - collective_op); + collective_op, use_cublasmp); } return ffi_with_cuda_error_check(); } @@ -151,7 +152,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, .Attr("fuse_gelu") .Attr("grad") .Attr("use_split_accumulator") - .Attr("collective_op")); + .Attr("collective_op") + .Attr("use_cublasmp")); Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, @@ -159,7 +161,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, - bool use_split_accumulator, JAXX_Collective_Op collective_op) { + bool use_split_accumulator, JAXX_Collective_Op collective_op, bool use_cublasmp) { // cuBLAS workspace + 256 alignment enforcement (+ swizzle scales) uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr; auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); @@ -279,7 +281,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, ", out_shape[1]=", out_shape[1]); auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( - buffer_shape, buffer_dtype, collective_op); + buffer_shape, buffer_dtype, collective_op, use_cublasmp); if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype); // Prepare the auxiliary buffer for the reduce-scattered GEMM output @@ -337,7 +339,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_gelu") .Attr("grad") .Attr("use_split_accumulator") - .Attr("collective_op"), + .Attr("collective_op") + .Attr("use_cublasmp"), FFI_CudaGraph_Traits); size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 978bee52dc1..b40280c4b88 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ 
b/transformer_engine/pytorch/csrc/common.h @@ -49,6 +49,7 @@ #include #include #include +#include #include #include "c10/util/ArrayRef.h" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 80479dccf48..1c69f6758f9 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -526,6 +526,11 @@ class CommOverlapHelper : torch::CustomClassHolder { ExtComm comm); void ub_barrier(ExtComm comm); + + int64_t get_nccl_comm_ptr(std::string comm_name) { + NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL."); + return reinterpret_cast(pgs[comm_name])->getCommPtr(); + } }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { @@ -537,6 +542,11 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); + CommOverlap(CommOverlapHelper *helper, int tp_size, int tp_rank, int num_comm_sm = 16, + bool atomic_gemm = false) + : CommOverlapBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm, + atomic_gemm) {} + ~CommOverlap() {} void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); @@ -558,6 +568,11 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, bool aggregate = false); + CommOverlapP2P(CommOverlapHelper *helper, int tp_size, int tp_rank, int num_comm_sm = 16, + bool atomic_gemm = false) + : CommOverlapP2PBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm, + atomic_gemm) {} + ~CommOverlapP2P() {} void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d0f450bc712..000f892fc20 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -490,6 +490,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) + .def(py::init(), py::arg("helper"), + py::arg("tp_size"), py::arg("tp_rank"), py::arg("num_comm_sm") = 0, + py::arg("atomic_gemm") = false, py::call_guard()) .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false, @@ -508,6 +511,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) + .def(py::init(), py::arg("helper"), + py::arg("tp_size"), py::arg("tp_rank"), py::arg("num_comm_sm") = 0, + py::arg("atomic_gemm") = false, py::call_guard()) .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index acf92332817..457e798b2b3 100644 --- a/transformer_engine/pytorch/module/base.py +++ 
b/transformer_engine/pytorch/module/base.py @@ -100,6 +100,7 @@ def initialize_ub( dtype: torch.dtype = torch.bfloat16, ub_cfgs: Optional[Union[dict, List[dict]]] = None, bootstrap_backend: Union[str, torch.distributed.Backend] = None, + with_cublasmp: bool = False, ) -> None: r""" Initialize the Userbuffers communicator for overlapping tensor-parallel communications with @@ -152,6 +153,10 @@ def initialize_ub( not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this option and always initializes Userbuffers with direct MPI calls in C++, which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time. + with_cublasmp : bool = False + Whether to use cuBlasMp for the all-gather and reduce-scatter overlaps. TE must + be compiled with `NVTE_WITH_CUBLASMP=1` for this option to work. + """ if not tex.device_supports_multicast(): assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( @@ -326,6 +331,7 @@ def add_ub( comm_priority: int = 0, gemm_priority: int = 0, pipeline_rs_overlap_first_gemm: bool = False, + with_cublasmp: bool = False, ) -> None: if atomic_gemm: warnings.warn( @@ -379,38 +385,56 @@ def add_ub( else dtype ) if method == "ring_exchange": - ub_obj = tex.CommOverlapP2P( - shape, # Communication buffer shape - buffer_dtype, # Communication buffer data type - helper, # Helper for torch.distributed callbacks during bootstrapping - tp_size, # Tensor-parallel group size (may be different than local_size) - tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, - num_max_streams=_NUM_MAX_UB_STREAMS, - comm_cga_size=cga_size, - num_comm_sm=num_sm, - set_sm_margin=set_sm_margin, - atomic_gemm=atomic_gemm, - use_ce=use_ce, - aggregate=aggregate, - gemm_priority=gemm_priority, - comm_priority=comm_priority, - ) + if with_cublasmp: + ub_obj = tex.CommOverlapP2P( + helper, + tp_size, + local_rank, + num_comm_sm=num_sm, + atomic_gemm=atomic_gemm, + ) + else: + ub_obj = tex.CommOverlapP2P( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping + tp_size, # Tensor-parallel group size (may be different than local_size) + tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + use_ce=use_ce, + aggregate=aggregate, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + ) else: - ub_obj = tex.CommOverlap( - shape, # Communication buffer shape - buffer_dtype, # Communication buffer data type - helper, # Helper for torch.distributed callbacks during bootstrapping - tp_size, # Tensor-parallel group size (may be different than local_size) - num_splits=num_splits, - num_max_streams=_NUM_MAX_UB_STREAMS, - comm_cga_size=cga_size, - num_comm_sm=num_sm, - set_sm_margin=set_sm_margin, - atomic_gemm=atomic_gemm, - gemm_priority=gemm_priority, - comm_priority=comm_priority, - rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, - ) + if with_cublasmp and method != "bulk": + ub_obj = tex.CommOverlap( + helper, + tp_size, + local_rank, + num_comm_sm=num_sm, + atomic_gemm=atomic_gemm, + ) + else: + ub_obj = tex.CommOverlap( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping + tp_size, # Tensor-parallel group size (may be different than local_size) + num_splits=num_splits, + 
num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, + ) _ub_communicators[(name, quantization_mode)] = ub_obj for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): @@ -480,6 +504,9 @@ def fill_userbuffers_buffer_for_all_gather( tensor's metadata, e.g. scaling factors. """ + # cuBlasMp already handles its own buffer filling and quantization factors + if comm.with_cublasmp(): + return # Tensor dimensions local_shape = local_tensor.size() From 7d46b0b7e07941627ecb396e8c7d3401c8992b8c Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 2 Dec 2025 20:12:43 +0000 Subject: [PATCH 2/9] added use_cublasmp flags to CollectiveGemm bootstrapping to avoid UB entirely Signed-off-by: Alp Dener --- transformer_engine/jax/csrc/extensions/cgemm_helper.cpp | 9 +++++---- transformer_engine/jax/csrc/extensions/cgemm_helper.h | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index 1ee47e46d30..79b1d02def6 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -136,20 +136,21 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces handler._initialize = true; - // Bootstrap UB via creating a dummy CommOverlapP2PBase object + // Bootstrap UB/cuBlasMp via creating a dummy CommOverlapP2PBase object std::vector buffer_shape{1, 1}; auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32, - JAXX_Collective_Op::ALL_GATHER); + JAXX_Collective_Op::ALL_GATHER, + use_cublasmp); } void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, int num_max_streams, int gemm_priority, int comm_priority, int num_comm_sm, bool use_ce, - bool aggregate_ag) { + bool aggregate_ag, bool use_cublasmp) { auto &config = CgemmConfig::get(false); config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag); auto &handler = CommunicatorHandler::get(false); - handler.init(num_total_devices, num_devices_per_process, process_id, tp_size); + handler.init(num_total_devices, num_devices_per_process, process_id, tp_size, use_cublasmp); } int GetCgemmNumMaxStreams() { diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h index cf25d4a0515..473b5f626c2 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.h +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -127,7 +127,8 @@ class CommunicatorHandler { int get_tp_num_domains() const { return tp_num_domains; } - static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size); + static void init(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, bool use_cublasmp = false); private: ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type); @@ -180,7 +181,7 @@ class CollectiveGemmPlanRegistry { void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, int num_max_streams, int gemm_priority, int comm_priority, int num_comm_sm, bool use_ce, - bool aggregate_ag); + bool aggregate_ag, bool use_cublasmp = 
false); int GetCgemmNumMaxStreams(); From 6d4a1417f9be663b8438da4ca542baa4fc8b643f Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 16 Dec 2025 19:04:45 +0000 Subject: [PATCH 3/9] added cuBLASMp backend option to JAX unit tests for CollectiveGEMM Signed-off-by: Alp Dener --- examples/jax/collective_gemm/common.py | 7 ++ .../jax/collective_gemm/run_test_cgemm.sh | 84 ++++++++++--------- transformer_engine/common/CMakeLists.txt | 2 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 2 + .../transformer_engine/comm_gemm_overlap.h | 2 +- transformer_engine/jax/cpp_extensions/gemm.py | 26 ++---- transformer_engine/pytorch/csrc/extensions.h | 4 +- 7 files changed, 68 insertions(+), 59 deletions(-) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index da79b213772..3ef786efa94 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -154,6 +154,7 @@ def _initialize_distributed(args): num_devices_per_process=devices_per_process, process_id=args.process_id, tensor_parallel_size=args.tensor_parallel_size, + use_cublasmp=args.use_cublasmp, ) @@ -241,5 +242,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para parser.add_argument( "--enable-result-check", action="store_true", default=True, help="Enable result checking" ) + parser.add_argument( + "--use-cublasmp", + action="store_true", + default=False, + help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation", + ) return parser diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index af263eb53d0..13b5daad8b8 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -65,50 +65,58 @@ for TEST_FILE in "${TEST_FILES[@]}"; do # Clear PIDs array for this test file PIDS=() - for i in $(seq 0 $(($NUM_GPUS - 1))); do - # Define output file for logs - LOG_FILE="${TEST_FILE}_gpu_${i}.log" - - if [ $i -eq 0 ]; then - # For process 0: show live output AND save to log file using tee - echo "=== Live output from process 0 ===" - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ - "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ - --num-processes=$NUM_GPUS \ - --process-id=$i 2>&1 | tee "$LOG_FILE" & - PID=$! - PIDS+=($PID) - else - # For other processes: redirect to log files only - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ - --num-processes=$NUM_GPUS \ - --process-id=$i > "$LOG_FILE" 2>&1 & - PID=$! - PIDS+=($PID) - fi + PYTEST_ARGS=( + "-vs" + "-c $TE_PATH/tests/jax/pytest.ini" + "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" + "--num-processes=$NUM_GPUS" + ) + + BACKENDS=("cublasmp" "userbuffers") + for backend in "${BACKENDS[@]}"; do + for i in $(seq 0 $(($NUM_GPUS - 1))); do + # Define output file for logs + LOG_FILE="${TEST_FILE}_gpu_${i}_${backend}.log" + + if [ $i -eq 0 ]; then + # For process 0: show live output AND save to log file using tee + echo "=== Live output from process 0 with ${backend} ===" + pytest --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ + "${PYTEST_ARGS[@]}" \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + # For other processes: redirect to log files only + pytest "${PYTEST_ARGS[@]}" \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! 
+ PIDS+=($PID) + fi + done done # Wait for all processes to finish wait # Check and print the log content from process 0 (now has log file thanks to tee) - if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE SKIPPED" - elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE FAILED" - HAS_FAILURE=1 - elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then - echo "... $TEST_FILE PASSED" - else - echo "... $TEST_FILE INVALID" - HAS_FAILURE=1 - fi - - # Remove the log files after processing them - wait - rm ${TEST_FILE}_gpu_*.log + for backend in "${BACKENDS[@]}"; do + if grep -q "SKIPPED" "${TEST_FILE}_gpu_0_${backend}.log"; then + echo "... $TEST_FILE SKIPPED for ${backend} backend" + elif grep -q "FAILED" "${TEST_FILE}_gpu_0_${backend}.log"; then + echo "... $TEST_FILE FAILED for ${backend} backend" + HAS_FAILURE=1 + elif grep -q "PASSED" "${TEST_FILE}_gpu_0_${backend}.log"; then + echo "... $TEST_FILE PASSED for ${backend} backend" + else + echo "... $TEST_FILE INVALID for ${backend} backend" + HAS_FAILURE=1 + fi + + # Remove the log files after processing them + wait + rm ${TEST_FILE}_gpu_*.log + done done wait diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 264f7f9a78d..ee3a90b8e3f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -271,7 +271,7 @@ if (NVTE_WITH_CUBLASMP) target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) find_library(CUBLASMP_LIB - NAMES cublasmp libcublasmp + NAMES cublasmp libcublasmp.so.0 PATHS ${CUBLASMP_DIR} PATH_SUFFIXES lib REQUIRED) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 8dab9492c00..5098b3d8ffb 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -79,6 +79,8 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm, bool is_p2p, bool atomic_gemm) { + NVTE_CHECK(nvte_built_with_cublasmp(), + "Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1."); _with_cublasmp = true; nvte_comm_gemm_ctx_create(reinterpret_cast(nccl_comm_ptr), tp_size, tp_rank); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 73abe45dcc2..74fb9d3d96c 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -18,7 +18,7 @@ #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 -/* \brief Check if TE is built with cuBlasMp. +/* \brief Check if TE is built with cuBLASMp. * * \return True if TE is built with cuBlasMp. 
*/ diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 37c6b01a9d0..e2f4b2377df 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -70,6 +70,7 @@ num_cublas_streams = get_num_compute_streams() +collective_gemm_with_cublasmp = False def get_cublas_workspace_size_bytes() -> None: @@ -198,6 +199,7 @@ def collective_gemm_bootstrap( num_sm_for_communication=2, use_ce=True, aggregate_all_gather=False, + use_cublasmp=False, ): """Initialize NCCL communicators for Collective GEMM operations. @@ -281,6 +283,8 @@ def collective_gemm_bootstrap( f" num_devices_per_process={num_devices_per_process}" ) assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}" + global collective_gemm_with_cublasmp + collective_gemm_with_cublasmp = use_cublasmp initialize_cgemm_communicator( num_total_devices, num_devices_per_process, @@ -292,6 +296,7 @@ def collective_gemm_bootstrap( num_sm_for_communication, use_ce, aggregate_all_gather, + use_cublasmp, ) @@ -386,7 +391,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) inner_primitive = None outer_primitive = None @@ -411,7 +416,6 @@ def abstract( sequence_dim, is_outer, collective_op, - use_cublasmp, ): del use_split_accumulator, transpose_batch_sequence @@ -539,7 +543,7 @@ def _dims_are_consecutive(dims): if scaling_mode.is_nvfp4_scaling: workspace_size += lhs_scale_inv.size + rhs_scale_inv.size if not collective_op.is_none: - if use_cublasmp: + if collective_gemm_with_cublasmp: # cuBlasMp manages its own cuBlasLt workspaces per stream workspace_size = 0 else: @@ -578,7 +582,6 @@ def lowering( sequence_dim, is_outer, collective_op, - use_cublasmp, ): del out_dtype, transpose_batch_sequence, sequence_dim, is_outer @@ -623,7 +626,7 @@ def lowering( "grad": grad, "use_split_accumulator": use_split_accumulator, "collective_op": int(collective_op.value), - "use_cublasmp": use_cublasmp, + "use_cublasmp": collective_gemm_with_cublasmp, } operand_output_aliases = {} @@ -658,7 +661,6 @@ def impl( sequence_dim, is_outer, collective_op, - use_cublasmp, ): if scaling_mode.is_1d_block_scaling(): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -726,7 +728,6 @@ def impl( transpose_batch_sequence=transpose_batch_sequence, sequence_dim=sequence_dim, is_outer=is_outer, - use_cublasmp=use_cublasmp, ) # Alter output blocks for CGEMM AG if ( @@ -778,7 +779,6 @@ def outer_impl( sequence_dim, is_outer, collective_op, - use_cublasmp, ): return GemmPrimitive.impl( lhs, @@ -800,7 +800,6 @@ def outer_impl( sequence_dim, is_outer, collective_op, - use_cublasmp, ) @staticmethod @@ -818,7 +817,6 @@ def batcher( sequence_dim, is_outer, collective_op, - use_cublasmp, ): del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None @@ -852,7 +850,6 @@ def batcher( transpose_batch_sequence=transpose_batch_sequence, sequence_dim=sequence_dim, is_outer=is_outer, - use_cublasmp=use_cublasmp, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -1015,7 +1012,6 @@ def infer_sharding_from_operands( sequence_dim, is_outer, collective_op, - use_cublasmp, mesh, arg_infos, result_infos, @@ -1027,7 +1023,6 @@ def infer_sharding_from_operands( result_infos, is_outer, sequence_dim, - use_cublasmp, ) (_, (out_specs, dbias_specs, 
pre_gelu_specs), *_) = ( @@ -1062,7 +1057,6 @@ def partition( sequence_dim, is_outer, collective_op, - use_cublasmp, mesh, arg_infos, result_infos, @@ -1141,7 +1135,6 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alph sequence_dim=inferred_sequence_dim, is_outer=False, collective_op=collective_op, - use_cublasmp=use_cublasmp, ) if reduce_spec is not None: @@ -1173,7 +1166,6 @@ def shardy_sharding_rule( sequence_dim, is_outer, collective_op, - use_cublasmp, mesh, operand_types, result_types, @@ -1268,7 +1260,6 @@ def _te_gemm( use_split_accumulator: bool = None, transpose_batch_sequence: bool = False, collective_op: CollectiveOp = CollectiveOp.NONE, - use_cublasmp: bool = False, ) -> Tuple[jax.Array, ...]: if grad or fuse_gelu: @@ -1372,7 +1363,6 @@ def _te_gemm( sequence_dim=-1, # Dummy value and will be set in the primitive is_outer=True, collective_op=collective_op, - use_cublasmp=use_cublasmp, ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1c69f6758f9..3990a22ea07 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -528,7 +528,9 @@ class CommOverlapHelper : torch::CustomClassHolder { void ub_barrier(ExtComm comm); int64_t get_nccl_comm_ptr(std::string comm_name) { - NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL."); + NVTE_CHECK(backend_is_nccl, + "Comm+GEMM overlap with cuBLASMp backend requires a tensor-parallel process ", + "group with NCCL backend."); return reinterpret_cast(pgs[comm_name])->getCommPtr(); } }; From 35d0f197dd89bc135ce326a7a6a8a77e6557780b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Dec 2025 19:05:45 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 61 +++++++++++-------- .../transformer_engine/comm_gemm_overlap.h | 8 +-- .../jax/csrc/extensions/cgemm_helper.cpp | 23 ++++--- .../jax/csrc/extensions/cgemm_helper.h | 11 ++-- .../jax/csrc/extensions/gemm.cpp | 20 +++--- transformer_engine/pytorch/csrc/extensions.h | 8 +-- 6 files changed, 67 insertions(+), 64 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 5098b3d8ffb..056cb3e4bf2 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -77,10 +77,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl num_comm_sm, set_sm_margin, use_ce, atomic_gemm); } -CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, - int num_comm_sm, bool is_p2p, bool atomic_gemm) { - NVTE_CHECK(nvte_built_with_cublasmp(), - "Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1."); +CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm, + bool is_p2p, bool atomic_gemm) { + NVTE_CHECK( + nvte_built_with_cublasmp(), + "Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1."); _with_cublasmp = true; nvte_comm_gemm_ctx_create(reinterpret_cast(nccl_comm_ptr), tp_size, tp_rank); @@ -195,11 +196,11 @@ CommOverlapCore::~CommOverlapCore() { if (_comm_created) { try { - #ifdef 
NVTE_UB_WITH_MPI
+#ifdef NVTE_UB_WITH_MPI
         destroy_communicator_mpi(_ub_comm);
-  #else
+#else
         destroy_communicator(_ub_comm);
-  #endif
+#endif
       } catch (const std::exception &e) {
         NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what());
       }
@@ -310,32 +311,32 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source
   return chunk;
 }
 
-void CommOverlapCore::cublasmp_ag_gemm(
-    const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
-    TensorWrapper &bias, TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
-    cudaStream_t stream_main) {
+void CommOverlapCore::cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B,
+                                       bool transb, TensorWrapper &D, TensorWrapper &bias,
+                                       TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
+                                       cudaStream_t stream_main) {
   int64_t m = transa ? A.size(0) : A.size(1);
   int64_t n_local = transb ? B.size(1) : B.size(0);
   int64_t n = n_local * _tp_size;
   int64_t k = transa ? A.size(1) : A.size(0);
 
   nvte_all_gather_gemm(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
-                      pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
-                      stream_main, _algo_type);
+                       pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
+                       stream_main, _algo_type);
 }
 
-void CommOverlapCore::cublasmp_gemm_rs(
-    const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
-    TensorWrapper &bias, TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
-    cudaStream_t stream_main) {
+void CommOverlapCore::cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
+                                       bool transb, TensorWrapper &D, TensorWrapper &bias,
+                                       TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
+                                       cudaStream_t stream_main) {
   int64_t m = transa ? A.size(0) : A.size(1);
   int64_t n = transb ? B.size(1) : B.size(0);
   int64_t k_local = transa ? A.size(1) : A.size(0);
   int64_t k = k_local * _tp_size;
 
   nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
-                          pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
-                          stream_main, _algo_type);
+                           pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
+                           stream_main, _algo_type);
 }
 
 /***************************************************************************************************
@@ -456,14 +457,15 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa
                                              bool use_split_accumulator, TensorWrapper &rs_output,
                                              cudaStream_t stream_main) {
   if (_with_cublasmp) {
-    return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main);
+    return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
+                            stream_main);
   }
-  
+
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
   _ub_comm->cga_size = _cga_size;
-  
+
   // Get GEMM dimensions
   size_t m = transa ? A.size(0) : A.size(1);
   size_t k = transa ? 
A.size(1) : A.size(0); @@ -557,7 +559,8 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { if (_with_cublasmp) { - return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); } // Get GEMM dimensions @@ -872,7 +875,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { if (_with_cublasmp) { - return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); } int ori_sms = _ub_comm->sms; @@ -978,7 +982,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { if (_with_cublasmp) { - return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); } int ori_sms = _ub_comm->sms; @@ -1149,7 +1154,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { if (_with_cublasmp) { - return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); } int ori_sms = _ub_comm->sms; @@ -1217,7 +1223,8 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { if (_with_cublasmp) { - return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, stream_main); + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); } int ori_sms = _ub_comm->sms; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 74fb9d3d96c..254e491c258 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -92,8 +92,8 @@ class CommOverlapCore { int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm); - CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm, - bool is_p2p, bool atomic_gemm); + CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm, bool is_p2p, + bool atomic_gemm); virtual ~CommOverlapCore(); @@ -203,7 +203,7 @@ class CommOverlapBase : public CommOverlapCore { CommOverlapBase(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm = 16, bool atomic_gemm = false) - : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, false, atomic_gemm) {} + : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, false, atomic_gemm) {} virtual ~CommOverlapBase(); @@ -287,7 +287,7 @@ class CommOverlapP2PBase : public CommOverlapCore { CommOverlapP2PBase(int64_t 
nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm = 1, bool atomic_gemm = false) - : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, true, atomic_gemm) {} + : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, true, atomic_gemm) {} virtual ~CommOverlapP2PBase(); diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index 79b1d02def6..af478232014 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -138,15 +138,14 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces // Bootstrap UB/cuBlasMp via creating a dummy CommOverlapP2PBase object std::vector buffer_shape{1, 1}; - auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32, - JAXX_Collective_Op::ALL_GATHER, - use_cublasmp); + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor( + buffer_shape, DType::kFloat32, JAXX_Collective_Op::ALL_GATHER, use_cublasmp); } void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, int num_max_streams, int gemm_priority, - int comm_priority, int num_comm_sm, bool use_ce, - bool aggregate_ag, bool use_cublasmp) { + int comm_priority, int num_comm_sm, bool use_ce, bool aggregate_ag, + bool use_cublasmp) { auto &config = CgemmConfig::get(false); config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag); auto &handler = CommunicatorHandler::get(false); @@ -196,18 +195,18 @@ CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector bu std::unique_ptr executor; if (use_cublasmp) { executor = std::make_unique( - reinterpret_cast(comm_handler.get_comm_for_current_device()), - comm_handler.tp_size, comm_handler.get_tp_domain_id(), - cgemm_config.num_comm_sm, cgemm_config.aggregate_ag); + reinterpret_cast(comm_handler.get_comm_for_current_device()), comm_handler.tp_size, + comm_handler.get_tp_domain_id(), cgemm_config.num_comm_sm, cgemm_config.aggregate_ag); } else { executor = std::make_unique( buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, - comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), - cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, - cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, - cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + comm_handler.allgather_func, comm_handler.barrier_func, + get_nvte_collective_op(collective_op), cgemm_config.num_max_streams, 1 /*comm_cga_size*/, + cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm, + true /*set_sm_margin*/, cgemm_config.use_ce, false /*atomic_gemm*/, + cgemm_config.aggregate_ag); } CommOverlapCore *executor_ptr = executor.get(); diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h index 473b5f626c2..b1210398c0a 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.h +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -127,8 +127,8 @@ class CommunicatorHandler { int get_tp_num_domains() const { return tp_num_domains; } - static void init(int 
num_total_devices, int num_devices_per_process, int process_id, - int tp_size, bool use_cublasmp = false); + static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, + bool use_cublasmp = false); private: ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type); @@ -166,8 +166,7 @@ class CollectiveGemmPlanRegistry { } CommOverlapCore *get_executor(std::vector buffer_shape, DType dtype, - JAXX_Collective_Op collective_op, - bool use_cublasmp = false); + JAXX_Collective_Op collective_op, bool use_cublasmp = false); private: CollectiveGemmPlanRegistry() {} @@ -180,8 +179,8 @@ class CollectiveGemmPlanRegistry { // Function declarations void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, int num_max_streams, int gemm_priority, - int comm_priority, int num_comm_sm, bool use_ce, - bool aggregate_ag, bool use_cublasmp = false); + int comm_priority, int num_comm_sm, bool use_ce, bool aggregate_ag, + bool use_cublasmp = false); int GetCgemmNumMaxStreams(); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 9dfb806cf85..ac5ce949e4a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -88,16 +88,13 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } -Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, - Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type bias_grad, - Result_Type pre_gelu_out, Result_Type workspace, - JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, - int64_t rhs_axis_boundary, bool lhs_transposed, - bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, - bool use_split_accumulator, JAXX_Collective_Op collective_op, - bool use_cublasmp) { +Error_Type CollectiveGemmInitFFI( + Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, + Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, + bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op, bool use_cublasmp) { nvte_cublas_handle_init(); // Init UB buffer @@ -161,7 +158,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, - bool use_split_accumulator, JAXX_Collective_Op collective_op, bool use_cublasmp) { + bool use_split_accumulator, JAXX_Collective_Op collective_op, + bool use_cublasmp) { // cuBLAS workspace + 256 alignment enforcement (+ swizzle scales) uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr; auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3990a22ea07..db4183771a1 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ 
b/transformer_engine/pytorch/csrc/extensions.h @@ -546,8 +546,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve CommOverlap(CommOverlapHelper *helper, int tp_size, int tp_rank, int num_comm_sm = 16, bool atomic_gemm = false) - : CommOverlapBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm, - atomic_gemm) {} + : CommOverlapBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm, + atomic_gemm) {} ~CommOverlap() {} @@ -572,8 +572,8 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm CommOverlapP2P(CommOverlapHelper *helper, int tp_size, int tp_rank, int num_comm_sm = 16, bool atomic_gemm = false) - : CommOverlapP2PBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm, - atomic_gemm) {} + : CommOverlapP2PBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm, + atomic_gemm) {} ~CommOverlapP2P() {} From dd8eaf318d6d0dc3603491387817f885999c9e0f Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 16 Dec 2025 21:27:46 +0000 Subject: [PATCH 5/9] added pytorch unit tests for comm+GEMM overlap with cuBLASMp backend Signed-off-by: Alp Dener --- .../distributed/run_gemm_with_overlap.py | 91 ++++++++++++------- .../distributed/run_layer_with_overlap.py | 7 ++ .../distributed/test_comm_gemm_overlap.py | 48 ++++++---- .../transformer_engine/comm_gemm_overlap.h | 2 +- .../jax/csrc/extensions/cgemm_helper.cpp | 2 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- 6 files changed, 98 insertions(+), 54 deletions(-) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 073fa08117f..20e6b13db09 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -151,6 +151,9 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs." ) + parser.add_argument( + "--use-cublasmp", action="store_true", default=False, help="Use cuBLASMp backend." + ) parser.add_argument( "-v", "--verbose", action="store_true", default=False, help="Verbose info messages." 
) @@ -323,47 +326,65 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ): buffer_dtype = torch.uint8 ub_obj = ( - tex.CommOverlapP2P( - (outer_size, hidden_size), - buffer_dtype, - helper, - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - opts.comm_type, - set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, - atomic_gemm=opts.atomic, - aggregate=opts.aggregate, - use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), - ) - if opts.p2p - else tex.CommOverlap( - (outer_size, hidden_size), - buffer_dtype, - helper, - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - atomic_gemm=opts.atomic, - ) - ) - - # Numerical check on AG + atomic GEMM requires testing an AG+RS pair - ub_obj2 = None - if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics: - ub_obj2 = ( + ( tex.CommOverlapP2P( (outer_size, hidden_size), - torch.uint8 if opts.fp8_output else torch.bfloat16, + buffer_dtype, helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - tex.CommOverlapType.RS, - set_sm_margin=True, - atomic_gemm=True, + opts.comm_type, + set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, + atomic_gemm=opts.atomic, + aggregate=opts.aggregate, + use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), + ) if not opts.use_cublasmp + else tex.CommOverlapP2P( + helper, tp_size, tp_rank, atomic_gemm=opts.atomic, ) - if opts.atomic_rs_p2p - else tex.CommOverlap( + ) if opts.p2p + else ( + tex.CommOverlap( (outer_size, hidden_size), - torch.uint8 if opts.fp8_output else torch.bfloat16, + buffer_dtype, helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - atomic_gemm=True, + atomic_gemm=opts.atomic, + ) if not opts.use_cublasmp + else tex.CommOverlap( + helper, tp_size, tp_rank, atomic_gemm=opts.atomic, + ) + ) + ) + + # Numerical check on AG + atomic GEMM requires testing an AG+RS pair + ub_obj2 = None + if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics: + ub_obj2 = ( + ( + tex.CommOverlapP2P( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, + helper, + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + tex.CommOverlapType.RS, + set_sm_margin=True, + atomic_gemm=True, + ) if not opts.use_cublasmp + else tex.CommOverlapP2P( + helper, tp_size, tp_rank, atomic_gemm=True + ) + ) if opts.atomic_rs_p2p + else ( + tex.CommOverlap( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, + helper, + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + atomic_gemm=True, + ) if not opts.use_cublasmp + else tex.CommOverlap( + helper, tp_size, tp_rank, atomic_gemm=True + ) ) ) @@ -387,6 +408,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None if opts.comm_type == tex.CommOverlapType.AG: # (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P) local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size) + local_kernel2_t_shape = (0, ) local_inp_shape = (outer_size // tp_size, hidden_size) if ub_obj2 is not None: local_kernel2_t_shape = (hidden_size, ffn_hidden_size // tp_size) @@ -408,7 +430,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None mean=0.0, std=opts.std, ) - if ub_obj2 is not None: + if opts.comm_type == tex.CommOverlapType.AG and ub_obj2 is not 
None: kernel2_t = torch.nn.init.normal_( torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), mean=0.0, @@ -457,6 +479,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ref_g = torch.stack(bulk_inp_list).sum(dim=0) else: ref_g = torch.matmul(inp_g, ker_g) + ref2_g = (0, ) if ub_obj2 is not None: inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = torch.matmul(inp2_g, ker2_g) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index b2bd6dd773c..6842570d46d 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -258,6 +258,12 @@ def _parse_args(argv=None, namespace=None): default=0, help="Number of layers at the end to run in bf16.", ) + parser.add_argument( + "--use-cublasmp", + action="store_true", + default=False, + help="Use cuBLASMp backend.", + ) args = parser.parse_args(argv, namespace) if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: @@ -436,6 +442,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, + with_cublasmp=opts.use_cublasmp, ) with te.quantized_model_init(enabled=opts.fp8_init): diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 3f4848e1053..2f8aa1f7b48 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -50,7 +50,8 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization): +def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization, + use_cublasmp=False): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -79,6 +80,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization test_cmd.append("--atomic") if aggregate: test_cmd.append("--aggregate") + if use_cublasmp: + test_cmd.append("--use-cublasmp") result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) if ( @@ -90,7 +93,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization def _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1 + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1, + use_cublasmp=False ): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ @@ -117,6 +121,9 @@ def _run_layer_with_overlap( test_cmd.append("--fp8") test_cmd.append(f"--quantization={quantization}") + if use_cublasmp: + test_cmd.append("--use-cublasmp") + os.environ["PYTORCH_JIT"] = "0" os.environ["NVTE_TORCH_COMPILE"] = "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" @@ -141,24 +148,26 @@ def _run_layer_with_overlap( raise AssertionError(result.stderr.decode()) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) @pytest.mark.parametrize("aggregate", (False, True)) -def test_split_all_gather_overlaps(quantization, aggregate): +def test_split_all_gather_overlaps(quantization, aggregate, use_cublasmp): """ Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or 
te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization) + _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization, use_cublasmp) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) @pytest.mark.parametrize("p2p", (False, True)) -def test_split_reduce_scatter_overlaps(quantization, p2p): +def test_split_reduce_scatter_overlaps(quantization, p2p, use_cublasmp): """ Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("RS", False, p2p, False, False, quantization) + _run_gemm_with_overlap("RS", False, p2p, False, False, quantization, use_cublasmp) @pytest.mark.parametrize( @@ -196,7 +205,7 @@ def test_bulk_overlaps(comm_type, quantization, connections): else: _run_gemm_with_overlap(comm_type, True, False, False, False, quantization) - +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "fp8", (False,), @@ -237,13 +246,14 @@ def test_bulk_overlaps(comm_type, quantization, connections): ) ], ) -def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): +def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, use_cublasmp, fp8): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None) - + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, + use_cublasmp=use_cublasmp) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "quantization", ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"], @@ -286,13 +296,15 @@ def test_layers_with_overlap_fp8( linear_parallel_mode, overlap_rs_dgrad, quantization, + use_cublasmp, ): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization) - + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, + use_cublasmp=use_cublasmp) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "fp8", (False,), @@ -325,16 +337,17 @@ def test_layers_with_overlap_fp8( ], ) def test_multi_layer_with_overlap_bf16( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers, use_cublasmp ): """ Test Transformer Engine layers with comm+GEMM overlap. """ _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers, + use_cublasmp=use_cublasmp ) - +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "quantization", ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"], @@ -361,11 +374,12 @@ def test_multi_layer_with_overlap_bf16( ], ) def test_multi_layer_with_overlap_fp8( - layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers + layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers, use_cublasmp ): """ Test Transformer Engine layers with comm+GEMM overlap. 
""" _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers + layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers, + use_cublasmp=use_cublasmp ) diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 254e491c258..136bd772718 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -20,7 +20,7 @@ /* \brief Check if TE is built with cuBLASMp. * - * \return True if TE is built with cuBlasMp. + * \return True if TE is built with cuBLASMp. */ bool nvte_built_with_cublasmp(); diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index af478232014..b153aa5dc59 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -132,7 +132,7 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces NVTE_CHECK_NCCL(ncclGroupEnd()); // Allocate device memory for barrier operations - NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int))); + NVTE_CHECK_CUDA(cudaMalloc(&reinterpret_cast(handler._device_barrier), sizeof(int))); handler._initialize = true; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index db4183771a1..84354742101 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -525,7 +525,7 @@ class CommOverlapHelper : torch::CustomClassHolder { void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, ExtComm comm); - void ub_barrier(ExtComm comm); + void ub_barrier(ExtComm comm);a int64_t get_nccl_comm_ptr(std::string comm_name) { NVTE_CHECK(backend_is_nccl, From d79bf21637481d0c0f186a15d0e47e6db88717e7 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 17 Dec 2025 02:15:17 +0000 Subject: [PATCH 6/9] greptile fixes Signed-off-by: Alp Dener --- .../distributed/run_gemm_with_overlap.py | 2 -- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 2 +- .../jax/csrc/extensions/cgemm_helper.cpp | 7 ++++--- transformer_engine/pytorch/csrc/extensions.h | 20 ++++++++++++------- .../pytorch/csrc/extensions/pybind.cpp | 4 ++-- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 20e6b13db09..586635fa036 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -408,7 +408,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None if opts.comm_type == tex.CommOverlapType.AG: # (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P) local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size) - local_kernel2_t_shape = (0, ) local_inp_shape = (outer_size // tp_size, hidden_size) if ub_obj2 is not None: local_kernel2_t_shape = (hidden_size, ffn_hidden_size // tp_size) @@ -479,7 +478,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ref_g = torch.stack(bulk_inp_list).sum(dim=0) else: ref_g = torch.matmul(inp_g, ker_g) - ref2_g = (0, ) if ub_obj2 is not None: inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = 
torch.matmul(inp2_g, ker2_g) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 056cb3e4bf2..e9bdba58726 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -332,7 +332,7 @@ void CommOverlapCore::cublasmp_gemm_rs(const TensorWrapper &A, bool transa, cons int64_t m = transa ? A.size(0) : A.size(1); int64_t n = transb ? B.size(1) : B.size(0); int64_t k_local = transa ? A.size(1) : A.size(0); - int64_t k = k * _tp_size; + int64_t k = k_local * _tp_size; nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm, diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index b153aa5dc59..e50ce01536e 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -132,7 +132,7 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces NVTE_CHECK_NCCL(ncclGroupEnd()); // Allocate device memory for barrier operations - NVTE_CHECK_CUDA(cudaMalloc(&reinterpret_cast(handler._device_barrier), sizeof(int))); + NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int))); handler._initialize = true; @@ -195,8 +195,9 @@ CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector bu std::unique_ptr executor; if (use_cublasmp) { executor = std::make_unique( - reinterpret_cast(comm_handler.get_comm_for_current_device()), comm_handler.tp_size, - comm_handler.get_tp_domain_id(), cgemm_config.num_comm_sm, cgemm_config.aggregate_ag); + reinterpret_cast(comm_handler.get_comm_for_current_device()), + comm_handler.get_tp_domain_id(), comm_handler.tp_size, cgemm_config.num_comm_sm, + cgemm_config.aggregate_ag); } else { executor = std::make_unique( buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 84354742101..5ead49336aa 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -525,13 +525,19 @@ class CommOverlapHelper : torch::CustomClassHolder { void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, ExtComm comm); - void ub_barrier(ExtComm comm);a + void ub_barrier(ExtComm comm); int64_t get_nccl_comm_ptr(std::string comm_name) { +#ifdef USE_C10_NCCL NVTE_CHECK(backend_is_nccl, "Comm+GEMM overlap with cuBLASMp backend requires a tensor-parallel process ", "group with NCCL backend."); - return reinterpret_cast(pgs[comm_name])->getCommPtr(); + c10d::ProcessGroupNCCL *nccl_pg = reinterpret_cast(pgs[comm_name]); + return nccl_pg->getCommPtr(); +#else + NVTE_ERROR("Internal TE Error: CommOverlapHelper::get_nccl_comm_ptr() is an internal API that ", + "should only be used when TE is built with the NVTE_WITH_CUBLASMP=1 flag."); +#endif } }; @@ -542,11 +548,11 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, - bool rs_overlap_first_gemm = false); + bool rs_overlap_first_gemm= 
false); - CommOverlap(CommOverlapHelper *helper, int tp_size, int tp_rank, int num_comm_sm = 16, + CommOverlap(CommOverlapHelper *helper, int tp_rank, int tp_size, int num_comm_sm = 16, bool atomic_gemm = false) - : CommOverlapBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm, + : CommOverlapBase(helper->get_nccl_comm_ptr("intra"), tp_rank, tp_size, num_comm_sm, atomic_gemm) {} ~CommOverlap() {} @@ -570,9 +576,9 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, bool aggregate = false); - CommOverlapP2P(CommOverlapHelper *helper, int tp_size, int tp_rank, int num_comm_sm = 16, + CommOverlapP2P(CommOverlapHelper *helper, int tp_rank, int tp_size, int num_comm_sm = 16, bool atomic_gemm = false) - : CommOverlapP2PBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm, + : CommOverlapP2PBase(helper->get_nccl_comm_ptr("intra"), tp_rank, tp_size, num_comm_sm, atomic_gemm) {} ~CommOverlapP2P() {} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 000f892fc20..722010f8fad 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -491,7 +491,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) .def(py::init(), py::arg("helper"), - py::arg("tp_size"), py::arg("tp_rank"), py::arg("num_comm_sm") = 0, + py::arg("tp_rank"), py::arg("tp_size"), py::arg("num_comm_sm") = 0, py::arg("atomic_gemm") = false, py::call_guard()) .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) @@ -512,7 +512,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) .def(py::init(), py::arg("helper"), - py::arg("tp_size"), py::arg("tp_rank"), py::arg("num_comm_sm") = 0, + py::arg("tp_rank"), py::arg("tp_size"), py::arg("num_comm_sm") = 0, py::arg("atomic_gemm") = false, py::call_guard()) .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) From ee517d3592e1ebde8cdfa5de0a00567d6c46792e Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 17 Dec 2025 02:50:11 +0000 Subject: [PATCH 7/9] linting Signed-off-by: Alp Dener --- 3rdparty/cudnn-frontend | 2 +- tests/pytorch/distributed/run_gemm_with_overlap.py | 2 +- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 5 ++++- transformer_engine/pytorch/module/base.py | 3 +-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4d5..be6c079be8a 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 586635fa036..35d7fec6b79 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -24,7 +24,7 @@ MXFP8Quantizer, ) import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.pytorch.cpp_extensions.gemm import 
get_cublas_workspace_size_bytes + from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather warnings.filterwarnings("ignore", category=DeprecationWarning) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index e9bdba58726..047095aaf04 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -84,8 +84,11 @@ CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size "Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1."); _with_cublasmp = true; - nvte_comm_gemm_ctx_create(reinterpret_cast(nccl_comm_ptr), tp_size, tp_rank); + _cublasmp_ctx = nvte_comm_gemm_ctx_create(reinterpret_cast(nccl_comm_ptr), tp_size, + tp_rank); + _tp_id = tp_rank; + _tp_size = tp_size; _num_comm_sm = num_comm_sm; _is_p2p = is_p2p; _atomic_gemm = atomic_gemm; diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 457e798b2b3..4d4bbac12fa 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -331,7 +331,6 @@ def add_ub( comm_priority: int = 0, gemm_priority: int = 0, pipeline_rs_overlap_first_gemm: bool = False, - with_cublasmp: bool = False, ) -> None: if atomic_gemm: warnings.warn( @@ -506,7 +505,7 @@ def fill_userbuffers_buffer_for_all_gather( """ # cuBlasMp already handles its own buffer filling and quantization factors if comm.with_cublasmp(): - return + return local_tensor, local_tensor # Tensor dimensions local_shape = local_tensor.size() From 51b64fb4cb86f8451ed9e7649be5a4b976890419 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 17 Dec 2025 03:35:45 +0000 Subject: [PATCH 8/9] function argument call order fixes Signed-off-by: Alp Dener --- tests/pytorch/distributed/run_gemm_with_overlap.py | 11 ++++++----- transformer_engine/pytorch/module/base.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 35d7fec6b79..539eaacbdf4 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -24,7 +24,7 @@ MXFP8Quantizer, ) import transformer_engine.pytorch.cpp_extensions as tex - +from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -206,6 +206,7 @@ def _main(opts): capture_output=True, text=True, shell=True, + check=False, ) if result.stdout == "0": # Extra checks for non-MNNVL platforms @@ -339,7 +340,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), ) if not opts.use_cublasmp else tex.CommOverlapP2P( - helper, tp_size, tp_rank, atomic_gemm=opts.atomic, + helper, tp_rank, tp_size, atomic_gemm=opts.atomic, ) ) if opts.p2p else ( @@ -351,7 +352,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None atomic_gemm=opts.atomic, ) if not opts.use_cublasmp else tex.CommOverlap( - helper, tp_size, tp_rank, atomic_gemm=opts.atomic, + helper, tp_rank, tp_size, atomic_gemm=opts.atomic, ) ) ) @@ -371,7 +372,7 @@ def 
dist_print(msg, src=None, info=False, error=False, section=False, group=None atomic_gemm=True, ) if not opts.use_cublasmp else tex.CommOverlapP2P( - helper, tp_size, tp_rank, atomic_gemm=True + helper, tp_rank, tp_size, atomic_gemm=True ) ) if opts.atomic_rs_p2p else ( @@ -383,7 +384,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None atomic_gemm=True, ) if not opts.use_cublasmp else tex.CommOverlap( - helper, tp_size, tp_rank, atomic_gemm=True + helper, tp_rank, tp_size, atomic_gemm=True ) ) ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4d4bbac12fa..c9ea961caab 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -387,8 +387,8 @@ def add_ub( if with_cublasmp: ub_obj = tex.CommOverlapP2P( helper, - tp_size, local_rank, + tp_size, num_comm_sm=num_sm, atomic_gemm=atomic_gemm, ) @@ -413,8 +413,8 @@ def add_ub( if with_cublasmp and method != "bulk": ub_obj = tex.CommOverlap( helper, - tp_size, local_rank, + tp_size, num_comm_sm=num_sm, atomic_gemm=atomic_gemm, ) From 9be771c41cae6c063c7bcc019a5225deeaa43106 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Dec 2025 03:37:41 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../distributed/run_gemm_with_overlap.py | 32 +++++++---- .../distributed/test_comm_gemm_overlap.py | 55 ++++++++++++++----- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 4 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- 4 files changed, 65 insertions(+), 28 deletions(-) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 539eaacbdf4..d595bd3677b 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -338,11 +338,16 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None atomic_gemm=opts.atomic, aggregate=opts.aggregate, use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), - ) if not opts.use_cublasmp + ) + if not opts.use_cublasmp else tex.CommOverlapP2P( - helper, tp_rank, tp_size, atomic_gemm=opts.atomic, + helper, + tp_rank, + tp_size, + atomic_gemm=opts.atomic, ) - ) if opts.p2p + ) + if opts.p2p else ( tex.CommOverlap( (outer_size, hidden_size), @@ -350,9 +355,13 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) atomic_gemm=opts.atomic, - ) if not opts.use_cublasmp + ) + if not opts.use_cublasmp else tex.CommOverlap( - helper, tp_rank, tp_size, atomic_gemm=opts.atomic, + helper, + tp_rank, + tp_size, + atomic_gemm=opts.atomic, ) ) ) @@ -370,11 +379,11 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None tex.CommOverlapType.RS, set_sm_margin=True, atomic_gemm=True, - ) if not opts.use_cublasmp - else tex.CommOverlapP2P( - helper, tp_rank, tp_size, atomic_gemm=True ) - ) if opts.atomic_rs_p2p + if not opts.use_cublasmp + else tex.CommOverlapP2P(helper, tp_rank, tp_size, atomic_gemm=True) + ) + if opts.atomic_rs_p2p else ( tex.CommOverlap( (outer_size, hidden_size), @@ -382,10 +391,9 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None helper, tp_size, # Tensor-parallel group size (may be different 
than LOCAL_SIZE) atomic_gemm=True, - ) if not opts.use_cublasmp - else tex.CommOverlap( - helper, tp_rank, tp_size, atomic_gemm=True ) + if not opts.use_cublasmp + else tex.CommOverlap(helper, tp_rank, tp_size, atomic_gemm=True) ) ) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 2f8aa1f7b48..107c1584590 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -50,8 +50,9 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization, - use_cublasmp=False): +def _run_gemm_with_overlap( + comm_type, bulk, p2p, atomic, aggregate, quantization, use_cublasmp=False +): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -93,8 +94,13 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization def _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1, - use_cublasmp=False + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + fp8, + quantization, + num_layers=1, + use_cublasmp=False, ): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ @@ -205,6 +211,7 @@ def test_bulk_overlaps(comm_type, quantization, connections): else: _run_gemm_with_overlap(comm_type, True, False, False, False, quantization) + @pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "fp8", @@ -246,12 +253,16 @@ def test_bulk_overlaps(comm_type, quantization, connections): ) ], ) -def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, use_cublasmp, fp8): +def test_layers_with_overlap_bf16( + layer_type, linear_parallel_mode, overlap_rs_dgrad, use_cublasmp, fp8 +): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, - use_cublasmp=use_cublasmp) + _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, use_cublasmp=use_cublasmp + ) + @pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( @@ -301,8 +312,15 @@ def test_layers_with_overlap_fp8( """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, - use_cublasmp=use_cublasmp) + _run_layer_with_overlap( + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + True, + quantization, + use_cublasmp=use_cublasmp, + ) + @pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( @@ -343,10 +361,16 @@ def test_multi_layer_with_overlap_bf16( Test Transformer Engine layers with comm+GEMM overlap. """ _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers, - use_cublasmp=use_cublasmp + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + fp8, + None, + num_layers, + use_cublasmp=use_cublasmp, ) + @pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "quantization", @@ -380,6 +404,11 @@ def test_multi_layer_with_overlap_fp8( Test Transformer Engine layers with comm+GEMM overlap. 
""" _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers, - use_cublasmp=use_cublasmp + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + True, + quantization, + num_layers, + use_cublasmp=use_cublasmp, ) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 047095aaf04..dbe6a9d27c1 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -84,8 +84,8 @@ CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size "Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1."); _with_cublasmp = true; - _cublasmp_ctx = nvte_comm_gemm_ctx_create(reinterpret_cast(nccl_comm_ptr), tp_size, - tp_rank); + _cublasmp_ctx = + nvte_comm_gemm_ctx_create(reinterpret_cast(nccl_comm_ptr), tp_size, tp_rank); _tp_id = tp_rank; _tp_size = tp_size; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 5ead49336aa..158f2b94bde 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -548,7 +548,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, - bool rs_overlap_first_gemm= false); + bool rs_overlap_first_gemm = false); CommOverlap(CommOverlapHelper *helper, int tp_rank, int tp_size, int num_comm_sm = 16, bool atomic_gemm = false)