diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4d5..be6c079be8a 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index b03ef04fa42..82df324cd7d 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -96,6 +96,12 @@ def setup_pytorch_extension( libraries.append("nvshmem_host") cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") + if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))): + # Creating a cuBlasMp context requires direct access to the underlying NCCL + # communicator in a tensor-parallel process group. The header for ProcessGroupNCCL + # needs this CPP directive to be included properly. + cxx_flags.append("-DUSE_C10D_NCCL") + # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index da79b213772..3ef786efa94 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -154,6 +154,7 @@ def _initialize_distributed(args): num_devices_per_process=devices_per_process, process_id=args.process_id, tensor_parallel_size=args.tensor_parallel_size, + use_cublasmp=args.use_cublasmp, ) @@ -241,5 +242,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para parser.add_argument( "--enable-result-check", action="store_true", default=True, help="Enable result checking" ) + parser.add_argument( + "--use-cublasmp", + action="store_true", + default=False, + help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation", + ) return parser diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index af263eb53d0..13b5daad8b8 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -65,50 +65,58 @@ for TEST_FILE in "${TEST_FILES[@]}"; do # Clear PIDs array for this test file PIDS=() - for i in $(seq 0 $(($NUM_GPUS - 1))); do - # Define output file for logs - LOG_FILE="${TEST_FILE}_gpu_${i}.log" - - if [ $i -eq 0 ]; then - # For process 0: show live output AND save to log file using tee - echo "=== Live output from process 0 ===" - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ - "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ - --num-processes=$NUM_GPUS \ - --process-id=$i 2>&1 | tee "$LOG_FILE" & - PID=$! - PIDS+=($PID) - else - # For other processes: redirect to log files only - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ - --num-processes=$NUM_GPUS \ - --process-id=$i > "$LOG_FILE" 2>&1 & - PID=$! 
-            PIDS+=($PID)
-        fi
+    PYTEST_ARGS=(
+        "-vs"
+        "-c"
+        "$TE_PATH/tests/jax/pytest.ini"
+        "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE"
+        "--num-processes=$NUM_GPUS"
+    )
+
+    BACKENDS=("cublasmp" "userbuffers")
+    for backend in "${BACKENDS[@]}"; do
+        for i in $(seq 0 $(($NUM_GPUS - 1))); do
+            # Define output file for logs
+            LOG_FILE="${TEST_FILE}_gpu_${i}_${backend}.log"
+
+            if [ $i -eq 0 ]; then
+                # For process 0: show live output AND save to log file using tee
+                echo "=== Live output from process 0 with ${backend} ==="
+                pytest --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}_${backend}.xml \
+                    "${PYTEST_ARGS[@]}" \
+                    --process-id=$i 2>&1 | tee "$LOG_FILE" &
+                PID=$!
+                PIDS+=($PID)
+            else
+                # For other processes: redirect to log files only
+                pytest "${PYTEST_ARGS[@]}" \
+                    --process-id=$i > "$LOG_FILE" 2>&1 &
+                PID=$!
+                PIDS+=($PID)
+            fi
+        done
     done

     # Wait for all processes to finish
     wait

     # Check and print the log content from process 0 (now has log file thanks to tee)
-    if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
-        echo "... $TEST_FILE SKIPPED"
-    elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
-        echo "... $TEST_FILE FAILED"
-        HAS_FAILURE=1
-    elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
-        echo "... $TEST_FILE PASSED"
-    else
-        echo "... $TEST_FILE INVALID"
-        HAS_FAILURE=1
-    fi
-
-    # Remove the log files after processing them
-    wait
-    rm ${TEST_FILE}_gpu_*.log
+    for backend in "${BACKENDS[@]}"; do
+        if grep -q "SKIPPED" "${TEST_FILE}_gpu_0_${backend}.log"; then
+            echo "... $TEST_FILE SKIPPED for ${backend} backend"
+        elif grep -q "FAILED" "${TEST_FILE}_gpu_0_${backend}.log"; then
+            echo "... $TEST_FILE FAILED for ${backend} backend"
+            HAS_FAILURE=1
+        elif grep -q "PASSED" "${TEST_FILE}_gpu_0_${backend}.log"; then
+            echo "... $TEST_FILE PASSED for ${backend} backend"
+        else
+            echo "... $TEST_FILE INVALID for ${backend} backend"
+            HAS_FAILURE=1
+        fi
+
+        # Remove only this backend's log files after processing them
+        wait
+        rm ${TEST_FILE}_gpu_*_${backend}.log
+    done
 done

 wait
diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py
index 073fa08117f..d595bd3677b 100644
--- a/tests/pytorch/distributed/run_gemm_with_overlap.py
+++ b/tests/pytorch/distributed/run_gemm_with_overlap.py
@@ -151,6 +151,9 @@ def _parse_args(argv=None, namespace=None):
     parser.add_argument(
         "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs."
     )
+    parser.add_argument(
+        "--use-cublasmp", action="store_true", default=False, help="Use cuBLASMp backend."
+    )
     parser.add_argument(
         "-v", "--verbose", action="store_true", default=False, help="Verbose info messages."
) @@ -203,6 +206,7 @@ def _main(opts): capture_output=True, text=True, shell=True, + check=False, ) if result.stdout == "0": # Extra checks for non-MNNVL platforms @@ -323,24 +327,42 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ): buffer_dtype = torch.uint8 ub_obj = ( - tex.CommOverlapP2P( - (outer_size, hidden_size), - buffer_dtype, - helper, - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - opts.comm_type, - set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, - atomic_gemm=opts.atomic, - aggregate=opts.aggregate, - use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), + ( + tex.CommOverlapP2P( + (outer_size, hidden_size), + buffer_dtype, + helper, + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + opts.comm_type, + set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, + atomic_gemm=opts.atomic, + aggregate=opts.aggregate, + use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), + ) + if not opts.use_cublasmp + else tex.CommOverlapP2P( + helper, + tp_rank, + tp_size, + atomic_gemm=opts.atomic, + ) ) if opts.p2p - else tex.CommOverlap( - (outer_size, hidden_size), - buffer_dtype, - helper, - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - atomic_gemm=opts.atomic, + else ( + tex.CommOverlap( + (outer_size, hidden_size), + buffer_dtype, + helper, + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + atomic_gemm=opts.atomic, + ) + if not opts.use_cublasmp + else tex.CommOverlap( + helper, + tp_rank, + tp_size, + atomic_gemm=opts.atomic, + ) ) ) @@ -348,22 +370,30 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ub_obj2 = None if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics: ub_obj2 = ( - tex.CommOverlapP2P( - (outer_size, hidden_size), - torch.uint8 if opts.fp8_output else torch.bfloat16, - helper, - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - tex.CommOverlapType.RS, - set_sm_margin=True, - atomic_gemm=True, + ( + tex.CommOverlapP2P( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, + helper, + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + tex.CommOverlapType.RS, + set_sm_margin=True, + atomic_gemm=True, + ) + if not opts.use_cublasmp + else tex.CommOverlapP2P(helper, tp_rank, tp_size, atomic_gemm=True) ) if opts.atomic_rs_p2p - else tex.CommOverlap( - (outer_size, hidden_size), - torch.uint8 if opts.fp8_output else torch.bfloat16, - helper, - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - atomic_gemm=True, + else ( + tex.CommOverlap( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, + helper, + tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + atomic_gemm=True, + ) + if not opts.use_cublasmp + else tex.CommOverlap(helper, tp_rank, tp_size, atomic_gemm=True) ) ) @@ -408,7 +438,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None mean=0.0, std=opts.std, ) - if ub_obj2 is not None: + if opts.comm_type == tex.CommOverlapType.AG and ub_obj2 is not None: kernel2_t = torch.nn.init.normal_( torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), mean=0.0, diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py 
index b2bd6dd773c..6842570d46d 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -258,6 +258,12 @@ def _parse_args(argv=None, namespace=None): default=0, help="Number of layers at the end to run in bf16.", ) + parser.add_argument( + "--use-cublasmp", + action="store_true", + default=False, + help="Use cuBLASMp backend.", + ) args = parser.parse_args(argv, namespace) if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: @@ -436,6 +442,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, + with_cublasmp=opts.use_cublasmp, ) with te.quantized_model_init(enabled=opts.fp8_init): diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 3f4848e1053..107c1584590 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -50,7 +50,9 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization): +def _run_gemm_with_overlap( + comm_type, bulk, p2p, atomic, aggregate, quantization, use_cublasmp=False +): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -79,6 +81,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization test_cmd.append("--atomic") if aggregate: test_cmd.append("--aggregate") + if use_cublasmp: + test_cmd.append("--use-cublasmp") result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) if ( @@ -90,7 +94,13 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization def _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1 + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + fp8, + quantization, + num_layers=1, + use_cublasmp=False, ): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ @@ -117,6 +127,9 @@ def _run_layer_with_overlap( test_cmd.append("--fp8") test_cmd.append(f"--quantization={quantization}") + if use_cublasmp: + test_cmd.append("--use-cublasmp") + os.environ["PYTORCH_JIT"] = "0" os.environ["NVTE_TORCH_COMPILE"] = "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" @@ -141,24 +154,26 @@ def _run_layer_with_overlap( raise AssertionError(result.stderr.decode()) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) @pytest.mark.parametrize("aggregate", (False, True)) -def test_split_all_gather_overlaps(quantization, aggregate): +def test_split_all_gather_overlaps(quantization, aggregate, use_cublasmp): """ Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. 
""" - _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization) + _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization, use_cublasmp) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) @pytest.mark.parametrize("p2p", (False, True)) -def test_split_reduce_scatter_overlaps(quantization, p2p): +def test_split_reduce_scatter_overlaps(quantization, p2p, use_cublasmp): """ Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("RS", False, p2p, False, False, quantization) + _run_gemm_with_overlap("RS", False, p2p, False, False, quantization, use_cublasmp) @pytest.mark.parametrize( @@ -197,6 +212,7 @@ def test_bulk_overlaps(comm_type, quantization, connections): _run_gemm_with_overlap(comm_type, True, False, False, False, quantization) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "fp8", (False,), @@ -237,13 +253,18 @@ def test_bulk_overlaps(comm_type, quantization, connections): ) ], ) -def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): +def test_layers_with_overlap_bf16( + layer_type, linear_parallel_mode, overlap_rs_dgrad, use_cublasmp, fp8 +): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None) + _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, use_cublasmp=use_cublasmp + ) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "quantization", ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"], @@ -286,13 +307,22 @@ def test_layers_with_overlap_fp8( linear_parallel_mode, overlap_rs_dgrad, quantization, + use_cublasmp, ): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization) + _run_layer_with_overlap( + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + True, + quantization, + use_cublasmp=use_cublasmp, + ) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "fp8", (False,), @@ -325,16 +355,23 @@ def test_layers_with_overlap_fp8( ], ) def test_multi_layer_with_overlap_bf16( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers, use_cublasmp ): """ Test Transformer Engine layers with comm+GEMM overlap. """ _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + fp8, + None, + num_layers, + use_cublasmp=use_cublasmp, ) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "quantization", ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"], @@ -361,11 +398,17 @@ def test_multi_layer_with_overlap_bf16( ], ) def test_multi_layer_with_overlap_fp8( - layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers + layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers, use_cublasmp ): """ Test Transformer Engine layers with comm+GEMM overlap. 
""" _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + True, + quantization, + num_layers, + use_cublasmp=use_cublasmp, ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 264f7f9a78d..ee3a90b8e3f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -271,7 +271,7 @@ if (NVTE_WITH_CUBLASMP) target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) find_library(CUBLASMP_LIB - NAMES cublasmp libcublasmp + NAMES cublasmp libcublasmp.so.0 PATHS ${CUBLASMP_DIR} PATH_SUFFIXES lib REQUIRED) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 56369db27fa..dbe6a9d27c1 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -23,6 +23,14 @@ using namespace std::placeholders; +bool nvte_built_with_cublasmp() { +#ifdef NVTE_WITH_CUBLASMP + return true; +#else + return false; +#endif +} + namespace transformer_engine { namespace { @@ -33,10 +41,6 @@ std::vector shape_to_vector(const NVTEShape &shape) { } // namespace -/*************************************************************************************************** - * Comm+GEMM Overlap Common Core - **************************************************************************************************/ - bool ubuf_built_with_mpi() { #ifdef NVTE_UB_WITH_MPI return true; @@ -45,6 +49,10 @@ bool ubuf_built_with_mpi() { #endif } +/*************************************************************************************************** + * Comm+GEMM Overlap Common Core + **************************************************************************************************/ + CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, @@ -69,6 +77,36 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl num_comm_sm, set_sm_margin, use_ce, atomic_gemm); } +CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm, + bool is_p2p, bool atomic_gemm) { + NVTE_CHECK( + nvte_built_with_cublasmp(), + "Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1."); + _with_cublasmp = true; + + _cublasmp_ctx = + nvte_comm_gemm_ctx_create(reinterpret_cast(nccl_comm_ptr), tp_size, tp_rank); + + _tp_id = tp_rank; + _tp_size = tp_size; + _num_comm_sm = num_comm_sm; + _is_p2p = is_p2p; + _atomic_gemm = atomic_gemm; + if (_is_p2p) { + if (_atomic_gemm) { + _algo_type = kNVTECommGemmAlgoAtomicP2P; + } else { + _algo_type = kNVTECommGemmAlgoSplitP2P; + } + } else { + if (_atomic_gemm) { + _algo_type = kNVTECommGemmAlgoAtomicMulticast; + } else { + _algo_type = kNVTECommGemmAlgoSplitMulticast; + } + } +} + void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, @@ -134,39 +172,43 @@ void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_stream } 
CommOverlapCore::~CommOverlapCore() { - cudaEventDestroy(_stop_comm); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - if (_comm_launch_event) { - cudaEventDestroy(_comm_launch_event); - } + if (_with_cublasmp) { + nvte_comm_gemm_ctx_destroy(_cublasmp_ctx); + } else { + cudaEventDestroy(_stop_comm); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); + if (_comm_launch_event) { + cudaEventDestroy(_comm_launch_event); + } - if (_atomic_gemm) { - cudaFree(_counter.dptr()); - } + if (_atomic_gemm) { + cudaFree(_counter.dptr()); + } - for (size_t i = 0; i < _stream_compute.size(); i++) { - cudaStreamSynchronize(_stream_compute[i]); - cudaStreamDestroy(_stream_compute[i]); - } + for (size_t i = 0; i < _stream_compute.size(); i++) { + cudaStreamSynchronize(_stream_compute[i]); + cudaStreamDestroy(_stream_compute[i]); + } - auto error = cudaGetLastError(); - if (error != cudaSuccess) { - NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); - } + auto error = cudaGetLastError(); + if (error != cudaSuccess) { + NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); + } - if (_comm_created) { - try { + if (_comm_created) { + try { #ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); + destroy_communicator_mpi(_ub_comm); #else - destroy_communicator(_ub_comm); + destroy_communicator(_ub_comm); #endif - } catch (const std::exception &e) { - NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } catch (const std::exception &e) { + NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } + _comm_created = false; } - _comm_created = false; } } @@ -272,6 +314,34 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source return chunk; } +void CommOverlapCore::cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, bool grad, bool accumulate, + cudaStream_t stream_main) { + int64_t m = transa ? A.size(0) : A.size(1); + int64_t n_local = transb ? B.size(1) : B.size(0); + int64_t n = n_local * _tp_size; + int64_t k = transa ? A.size(1) : A.size(0); + + nvte_all_gather_gemm(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm, + stream_main, _algo_type); +} + +void CommOverlapCore::cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, bool grad, bool accumulate, + cudaStream_t stream_main) { + int64_t m = transa ? A.size(0) : A.size(1); + int64_t n = transb ? B.size(1) : B.size(0); + int64_t k_local = transa ? 
A.size(1) : A.size(0); + int64_t k = k_local * _tp_size; + + nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm, + stream_main, _algo_type); +} + /*************************************************************************************************** * Comm+GEMM Overlap Base (Pipelined / Collective) **************************************************************************************************/ @@ -313,9 +383,11 @@ void CommOverlapBase::initialize(const std::vector &buffer_shape, DType } CommOverlapBase::~CommOverlapBase() { - cudaEventDestroy(_start_d2dcopy); - cudaStreamSynchronize(_stream_comm); - cudaStreamDestroy(_stream_comm); + if (!_with_cublasmp) { + cudaEventDestroy(_start_d2dcopy); + cudaStreamSynchronize(_stream_comm); + cudaStreamDestroy(_stream_comm); + } } /* @@ -328,6 +400,8 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_CHECK(!_with_cublasmp, "Bulk overlap is not supported with cuBlasMp"); + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -385,10 +459,16 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions size_t m = transa ? A.size(0) : A.size(1); size_t k = transa ? 
A.size(1) : A.size(0); @@ -481,6 +561,11 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -724,16 +809,20 @@ void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DTy } CommOverlapP2PBase::~CommOverlapP2PBase() { - cudaEventDestroy(_stop_recv); - cudaEventDestroy(_stop_send); - cudaStreamDestroy(_stream_recv); - for (size_t i = 0; i < _stream_send.size(); i++) { - cudaStreamDestroy(_stream_send[i]); + if (!_with_cublasmp) { + cudaEventDestroy(_stop_recv); + cudaEventDestroy(_stop_send); + cudaStreamDestroy(_stream_recv); + for (size_t i = 0; i < _stream_send.size(); i++) { + cudaStreamDestroy(_stream_send[i]); + } } } void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, bool rowwise) { + if (_with_cublasmp) return; // cuBlasMp executes its own copy-into-buffer op + // Check element size const size_t element_size = source.element_size(); NVTE_CHECK(_ubuf.element_size() == element_size, @@ -788,6 +877,11 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -890,6 +984,11 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -1057,6 +1156,11 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -1121,6 +1225,11 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index cffc411a0d3..136bd772718 100644 --- 
a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -17,6 +18,12 @@ #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +/* \brief Check if TE is built with cuBLASMp. + * + * \return True if TE is built with cuBLASMp. + */ +bool nvte_built_with_cublasmp(); + namespace transformer_engine { /* \brief Check if Userbufers bootstraps with direct calls to MPI collectives. @@ -59,6 +66,10 @@ class CommOverlapCore { bool _atomic_gemm{false}; bool _is_p2p{false}; + bool _with_cublasmp{false}; + NVTECommGemmCtx *_cublasmp_ctx{nullptr}; + NVTECommGemmAlgoType _algo_type = kNVTECommGemmAlgoDefault; + TensorWrapper _ubuf; TensorWrapper _counter; float *_ubuf_scale_inv; @@ -81,6 +92,9 @@ class CommOverlapCore { int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm); + CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm, bool is_p2p, + bool atomic_gemm); + virtual ~CommOverlapCore(); void *get_ubuf_dptr() { return _ubuf.dptr(); } @@ -109,6 +123,16 @@ class CommOverlapCore { bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + bool with_cublasmp() { return _with_cublasmp; } + + void cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + bool grad, bool accumulate, cudaStream_t stream_main); + + void cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + bool grad, bool accumulate, cudaStream_t stream_main); + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, @@ -177,6 +201,10 @@ class CommOverlapBase : public CommOverlapCore { bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); + CommOverlapBase(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm = 16, + bool atomic_gemm = false) + : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, false, atomic_gemm) {} + virtual ~CommOverlapBase(); /* @@ -257,6 +285,10 @@ class CommOverlapP2PBase : public CommOverlapCore { int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + CommOverlapP2PBase(int64_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm = 1, + bool atomic_gemm = false) + : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, true, atomic_gemm) {} + virtual ~CommOverlapP2PBase(); void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index bce124e705f..34e664ccbf0 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -112,6 +112,8 @@ .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ py::call_guard()) \ .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()) \ + .def("with_cublasmp", &transformer_engine::CommOverlapCore::with_cublasmp, \ py::call_guard()); \ py::class_, \ @@ -135,6 +137,8 @@ }, \ py::call_guard(), py::arg("device_id") = -1); \ 
m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); \ + m.def("nvte_built_with_cublasmp", &nvte_built_with_cublasmp, \ py::call_guard()); #endif diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 76a8b225ba8..e2f4b2377df 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -70,6 +70,7 @@ num_cublas_streams = get_num_compute_streams() +collective_gemm_with_cublasmp = False def get_cublas_workspace_size_bytes() -> None: @@ -198,6 +199,7 @@ def collective_gemm_bootstrap( num_sm_for_communication=2, use_ce=True, aggregate_all_gather=False, + use_cublasmp=False, ): """Initialize NCCL communicators for Collective GEMM operations. @@ -281,6 +283,8 @@ def collective_gemm_bootstrap( f" num_devices_per_process={num_devices_per_process}" ) assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}" + global collective_gemm_with_cublasmp + collective_gemm_with_cublasmp = use_cublasmp initialize_cgemm_communicator( num_total_devices, num_devices_per_process, @@ -292,6 +296,7 @@ def collective_gemm_bootstrap( num_sm_for_communication, use_ce, aggregate_all_gather, + use_cublasmp, ) @@ -538,7 +543,11 @@ def _dims_are_consecutive(dims): if scaling_mode.is_nvfp4_scaling: workspace_size += lhs_scale_inv.size + rhs_scale_inv.size if not collective_op.is_none: - workspace_size *= get_cgemm_num_max_streams() + if collective_gemm_with_cublasmp: + # cuBlasMp manages its own cuBlasLt workspaces per stream + workspace_size = 0 + else: + workspace_size *= get_cgemm_num_max_streams() # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. 
workspace_size += 256 @@ -617,6 +626,7 @@ def lowering( "grad": grad, "use_split_accumulator": use_split_accumulator, "collective_op": int(collective_op.value), + "use_cublasmp": collective_gemm_with_cublasmp, } operand_output_aliases = {} @@ -803,10 +813,10 @@ def batcher( fuse_gelu, grad, use_split_accumulator, - collective_op, transpose_batch_sequence, sequence_dim, is_outer, + collective_op, ): del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index 7082bfb035b..e50ce01536e 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -62,7 +62,7 @@ ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &i } void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id, - int tp_size) { + int tp_size, bool use_cublasmp) { // Validate inputs NVTE_CHECK(num_devices_per_process == 1, "num_devices_per_process must be == 1, got num_devices_per_process=", @@ -136,20 +136,20 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces handler._initialize = true; - // Bootstrap UB via creating a dummy CommOverlapP2PBase object + // Bootstrap UB/cuBlasMp via creating a dummy CommOverlapP2PBase object std::vector buffer_shape{1, 1}; - auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32, - JAXX_Collective_Op::ALL_GATHER); + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor( + buffer_shape, DType::kFloat32, JAXX_Collective_Op::ALL_GATHER, use_cublasmp); } void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, int num_max_streams, int gemm_priority, - int comm_priority, int num_comm_sm, bool use_ce, - bool aggregate_ag) { + int comm_priority, int num_comm_sm, bool use_ce, bool aggregate_ag, + bool use_cublasmp) { auto &config = CgemmConfig::get(false); config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag); auto &handler = CommunicatorHandler::get(false); - handler.init(num_total_devices, num_devices_per_process, process_id, tp_size); + handler.init(num_total_devices, num_devices_per_process, process_id, tp_size, use_cublasmp); } int GetCgemmNumMaxStreams() { @@ -159,7 +159,8 @@ int GetCgemmNumMaxStreams() { CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector buffer_shape, DType dtype, - JAXX_Collective_Op collective_op) { + JAXX_Collective_Op collective_op, + bool use_cublasmp) { auto &comm_handler = CommunicatorHandler::get(); auto &cgemm_config = CgemmConfig::get(); @@ -192,14 +193,22 @@ CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector bu } std::unique_ptr executor; - executor = std::make_unique( - buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, - comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, - comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, - comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), - cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, - cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, - cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + if 
(use_cublasmp) { + executor = std::make_unique( + reinterpret_cast(comm_handler.get_comm_for_current_device()), + comm_handler.get_tp_domain_id(), comm_handler.tp_size, cgemm_config.num_comm_sm, + cgemm_config.aggregate_ag); + } else { + executor = std::make_unique( + buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, + comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, + comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, + comm_handler.allgather_func, comm_handler.barrier_func, + get_nvte_collective_op(collective_op), cgemm_config.num_max_streams, 1 /*comm_cga_size*/, + cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm, + true /*set_sm_margin*/, cgemm_config.use_ce, false /*atomic_gemm*/, + cgemm_config.aggregate_ag); + } CommOverlapCore *executor_ptr = executor.get(); plan_map[plan_id] = std::move(executor); diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h index 84b2b815401..b1210398c0a 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.h +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -127,7 +127,8 @@ class CommunicatorHandler { int get_tp_num_domains() const { return tp_num_domains; } - static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size); + static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, + bool use_cublasmp = false); private: ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type); @@ -165,7 +166,7 @@ class CollectiveGemmPlanRegistry { } CommOverlapCore *get_executor(std::vector buffer_shape, DType dtype, - JAXX_Collective_Op collective_op); + JAXX_Collective_Op collective_op, bool use_cublasmp = false); private: CollectiveGemmPlanRegistry() {} @@ -178,8 +179,8 @@ class CollectiveGemmPlanRegistry { // Function declarations void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, int num_max_streams, int gemm_priority, - int comm_priority, int num_comm_sm, bool use_ce, - bool aggregate_ag); + int comm_priority, int num_comm_sm, bool use_ce, bool aggregate_ag, + bool use_cublasmp = false); int GetCgemmNumMaxStreams(); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6566ff1689e..ac5ce949e4a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -88,15 +88,13 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } -Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, - Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type bias_grad, - Result_Type pre_gelu_out, Result_Type workspace, - JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, - int64_t rhs_axis_boundary, bool lhs_transposed, - bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, - bool use_split_accumulator, JAXX_Collective_Op collective_op) { +Error_Type CollectiveGemmInitFFI( + Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, + Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, 
Result_Type workspace,
+    JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary,
+    bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
+    bool use_split_accumulator, JAXX_Collective_Op collective_op, bool use_cublasmp) {
   nvte_cublas_handle_init();

   // Init UB buffer
@@ -123,7 +121,7 @@
       buffer_shape[1] = out_shape[1];
     }
     auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype,
-                                                                    collective_op);
+                                                                    collective_op, use_cublasmp);
   }
   return ffi_with_cuda_error_check();
 }
@@ -151,7 +149,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
                                   .Attr<bool>("fuse_gelu")
                                   .Attr<bool>("grad")
                                   .Attr<bool>("use_split_accumulator")
-                                  .Attr<JAXX_Collective_Op>("collective_op"));
+                                  .Attr<JAXX_Collective_Op>("collective_op")
+                                  .Attr<bool>("use_cublasmp"));

 Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
                    Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
@@ -159,7 +158,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
                    Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode,
                    int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed,
                    bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
-                   bool use_split_accumulator, JAXX_Collective_Op collective_op) {
+                   bool use_split_accumulator, JAXX_Collective_Op collective_op,
+                   bool use_cublasmp) {
   // cuBLAS workspace + 256 alignment enforcement (+ swizzle scales)
   uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr;
   auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
@@ -279,7 +279,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
   NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
              ", out_shape[1]=", out_shape[1]);
   auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
-      buffer_shape, buffer_dtype, collective_op);
+      buffer_shape, buffer_dtype, collective_op, use_cublasmp);
   if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
     auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype);
     // Prepare the auxiliary buffer for the reduce-scattered GEMM output
@@ -337,7 +337,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
                                   .Attr<bool>("fuse_gelu")
                                   .Attr<bool>("grad")
                                   .Attr<bool>("use_split_accumulator")
-                                  .Attr<JAXX_Collective_Op>("collective_op"),
+                                  .Attr<JAXX_Collective_Op>("collective_op")
+                                  .Attr<bool>("use_cublasmp"),
                               FFI_CudaGraph_Traits);

 size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes,
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 978bee52dc1..b40280c4b88 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -49,6 +49,7 @@
 #include
 #include
 #include
+#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 #include

 #include "c10/util/ArrayRef.h"
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 80479dccf48..158f2b94bde 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -526,6 +526,19 @@ class CommOverlapHelper : torch::CustomClassHolder {
                     ExtComm comm);

   void ub_barrier(ExtComm comm);
+
+  int64_t get_nccl_comm_ptr(std::string comm_name) {
+#ifdef USE_C10D_NCCL
+    NVTE_CHECK(backend_is_nccl,
+               "Comm+GEMM overlap with cuBLASMp backend requires a tensor-parallel process ",
"group with NCCL backend."); + c10d::ProcessGroupNCCL *nccl_pg = reinterpret_cast(pgs[comm_name]); + return nccl_pg->getCommPtr(); +#else + NVTE_ERROR("Internal TE Error: CommOverlapHelper::get_nccl_comm_ptr() is an internal API that ", + "should only be used when TE is built with the NVTE_WITH_CUBLASMP=1 flag."); +#endif + } }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { @@ -537,6 +550,11 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); + CommOverlap(CommOverlapHelper *helper, int tp_rank, int tp_size, int num_comm_sm = 16, + bool atomic_gemm = false) + : CommOverlapBase(helper->get_nccl_comm_ptr("intra"), tp_rank, tp_size, num_comm_sm, + atomic_gemm) {} + ~CommOverlap() {} void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); @@ -558,6 +576,11 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, bool aggregate = false); + CommOverlapP2P(CommOverlapHelper *helper, int tp_rank, int tp_size, int num_comm_sm = 16, + bool atomic_gemm = false) + : CommOverlapP2PBase(helper->get_nccl_comm_ptr("intra"), tp_rank, tp_size, num_comm_sm, + atomic_gemm) {} + ~CommOverlapP2P() {} void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d0f450bc712..722010f8fad 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -490,6 +490,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) + .def(py::init(), py::arg("helper"), + py::arg("tp_rank"), py::arg("tp_size"), py::arg("num_comm_sm") = 0, + py::arg("atomic_gemm") = false, py::call_guard()) .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false, @@ -508,6 +511,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) + .def(py::init(), py::arg("helper"), + py::arg("tp_rank"), py::arg("tp_size"), py::arg("num_comm_sm") = 0, + py::arg("atomic_gemm") = false, py::call_guard()) .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index acf92332817..c9ea961caab 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -100,6 +100,7 @@ def initialize_ub( dtype: torch.dtype = torch.bfloat16, ub_cfgs: Optional[Union[dict, List[dict]]] = None, bootstrap_backend: Union[str, torch.distributed.Backend] = None, + with_cublasmp: bool = False, ) -> None: r""" Initialize the Userbuffers communicator for overlapping tensor-parallel communications 
with @@ -152,6 +153,10 @@ def initialize_ub( not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this option and always initializes Userbuffers with direct MPI calls in C++, which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time. + with_cublasmp : bool = False + Whether to use cuBlasMp for the all-gather and reduce-scatter overlaps. TE must + be compiled with `NVTE_WITH_CUBLASMP=1` for this option to work. + """ if not tex.device_supports_multicast(): assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( @@ -379,38 +384,56 @@ def add_ub( else dtype ) if method == "ring_exchange": - ub_obj = tex.CommOverlapP2P( - shape, # Communication buffer shape - buffer_dtype, # Communication buffer data type - helper, # Helper for torch.distributed callbacks during bootstrapping - tp_size, # Tensor-parallel group size (may be different than local_size) - tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, - num_max_streams=_NUM_MAX_UB_STREAMS, - comm_cga_size=cga_size, - num_comm_sm=num_sm, - set_sm_margin=set_sm_margin, - atomic_gemm=atomic_gemm, - use_ce=use_ce, - aggregate=aggregate, - gemm_priority=gemm_priority, - comm_priority=comm_priority, - ) + if with_cublasmp: + ub_obj = tex.CommOverlapP2P( + helper, + local_rank, + tp_size, + num_comm_sm=num_sm, + atomic_gemm=atomic_gemm, + ) + else: + ub_obj = tex.CommOverlapP2P( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping + tp_size, # Tensor-parallel group size (may be different than local_size) + tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + use_ce=use_ce, + aggregate=aggregate, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + ) else: - ub_obj = tex.CommOverlap( - shape, # Communication buffer shape - buffer_dtype, # Communication buffer data type - helper, # Helper for torch.distributed callbacks during bootstrapping - tp_size, # Tensor-parallel group size (may be different than local_size) - num_splits=num_splits, - num_max_streams=_NUM_MAX_UB_STREAMS, - comm_cga_size=cga_size, - num_comm_sm=num_sm, - set_sm_margin=set_sm_margin, - atomic_gemm=atomic_gemm, - gemm_priority=gemm_priority, - comm_priority=comm_priority, - rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, - ) + if with_cublasmp and method != "bulk": + ub_obj = tex.CommOverlap( + helper, + local_rank, + tp_size, + num_comm_sm=num_sm, + atomic_gemm=atomic_gemm, + ) + else: + ub_obj = tex.CommOverlap( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping + tp_size, # Tensor-parallel group size (may be different than local_size) + num_splits=num_splits, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, + ) _ub_communicators[(name, quantization_mode)] = ub_obj for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): @@ -480,6 +503,9 @@ def fill_userbuffers_buffer_for_all_gather( tensor's metadata, e.g. scaling factors. 
""" + # cuBlasMp already handles its own buffer filling and quantization factors + if comm.with_cublasmp(): + return local_tensor, local_tensor # Tensor dimensions local_shape = local_tensor.size()