6 changes: 6 additions & 0 deletions build_tools/pytorch.py
@@ -96,6 +96,12 @@ def setup_pytorch_extension(
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))):
# Creating a cuBLASMp context requires direct access to the underlying NCCL
# communicator of the tensor-parallel process group. The ProcessGroupNCCL header
# needs this preprocessor directive in order to be included properly.
cxx_flags.append("-DUSE_C10D_NCCL")

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
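
As context for the hunk above: it follows the same pattern as the existing NVSHMEM toggle, where an environment variable gates an extra preprocessor define passed to the C++ extension build. Below is a minimal, self-contained sketch of that pattern; the helper name and the exact set of toggles are illustrative assumptions, not code from this diff.

```python
import os
from typing import List


def collect_env_gated_cxx_flags() -> List[str]:
    """Illustrative sketch of the env-gated define pattern used in build_tools/pytorch.py."""
    cxx_flags: List[str] = []
    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
    if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
        # The ProcessGroupNCCL header only exposes its NCCL declarations when
        # USE_C10D_NCCL is defined; the cuBLASMp path needs the underlying NCCL
        # communicator of the tensor-parallel process group.
        cxx_flags.append("-DUSE_C10D_NCCL")
    return cxx_flags
```
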
7 changes: 7 additions & 0 deletions examples/jax/collective_gemm/common.py
@@ -154,6 +154,7 @@ def _initialize_distributed(args):
num_devices_per_process=devices_per_process,
process_id=args.process_id,
tensor_parallel_size=args.tensor_parallel_size,
use_cublasmp=args.use_cublasmp,
)


@@ -241,5 +242,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para
parser.add_argument(
"--enable-result-check", action="store_true", default=True, help="Enable result checking"
)
parser.add_argument(
"--use-cublasmp",
action="store_true",
default=False,
help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation",
)

return parser
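
As a usage sketch (not part of the diff), the new flag flows from the parser into the distributed initialization shown in the first hunk. Only `--use-cublasmp` below comes from this PR; any other arguments the real parser requires are omitted here for brevity.

```python
# Hypothetical usage sketch for the updated parser.
parser = cgemm_parser()
args = parser.parse_args(["--use-cublasmp"])
assert args.use_cublasmp is True  # store_true flag, defaults to False when omitted
```
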
84 changes: 46 additions & 38 deletions examples/jax/collective_gemm/run_test_cgemm.sh
@@ -65,50 +65,58 @@ for TEST_FILE in "${TEST_FILES[@]}"; do
# Clear PIDs array for this test file
PIDS=()

for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_FILE}_gpu_${i}.log"

if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
PYTEST_ARGS=(
"-vs"
"-c $TE_PATH/tests/jax/pytest.ini"
"$TE_PATH/examples/jax/collective_gemm/$TEST_FILE"
"--num-processes=$NUM_GPUS"
)

BACKENDS=("cublasmp" "userbuffers")
for backend in "${BACKENDS[@]}"; do
for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_FILE}_gpu_${i}_${backend}.log"

if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 with ${backend} ==="
pytest --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
"${PYTEST_ARGS[@]}" \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest "${PYTEST_ARGS[@]}" \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
done
done

# Wait for all processes to finish
wait

# Check and print the log content from process 0 (now has log file thanks to tee)
if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE SKIPPED"
elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE PASSED"
else
echo "... $TEST_FILE INVALID"
HAS_FAILURE=1
fi

# Remove the log files after processing them
wait
rm ${TEST_FILE}_gpu_*.log
for backend in "${BACKENDS[@]}"; do
if grep -q "SKIPPED" "${TEST_FILE}_gpu_0_${backend}.log"; then
echo "... $TEST_FILE SKIPPED for ${backend} backend"
elif grep -q "FAILED" "${TEST_FILE}_gpu_0_${backend}.log"; then
echo "... $TEST_FILE FAILED for ${backend} backend"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_FILE}_gpu_0_${backend}.log"; then
echo "... $TEST_FILE PASSED for ${backend} backend"
else
echo "... $TEST_FILE INVALID for ${backend} backend"
HAS_FAILURE=1
fi

done

# Remove the log files after processing them
rm ${TEST_FILE}_gpu_*.log
done

wait
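
For readers less familiar with the shell pattern above, here is a rough Python equivalent of the launch loop, included only as an illustration: one pytest process per GPU per backend, each redirected to its own log file. The GPU count and test file name are placeholders, and (as in the script) the backend name is only used to label the logs.

```python
# Illustrative sketch of the per-rank launch pattern used by run_test_cgemm.sh.
import subprocess

NUM_GPUS = 8                                # placeholder
BACKENDS = ["cublasmp", "userbuffers"]
TEST_FILE = "test_cgemm.py"                 # placeholder

for backend in BACKENDS:
    procs = []
    for rank in range(NUM_GPUS):
        log_path = f"{TEST_FILE}_gpu_{rank}_{backend}.log"
        cmd = [
            "pytest",
            "-vs",
            TEST_FILE,
            f"--num-processes={NUM_GPUS}",
            f"--process-id={rank}",
        ]
        with open(log_path, "w") as log:
            procs.append(
                subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)
            )
    # Wait for every rank before inspecting rank 0's log.
    for proc in procs:
        proc.wait()
```
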
92 changes: 61 additions & 31 deletions tests/pytorch/distributed/run_gemm_with_overlap.py
@@ -151,6 +151,9 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs."
)
parser.add_argument(
"--use-cublasmp", action="store_true", default=False, help="Use cuBLASMp backend."
)
parser.add_argument(
"-v", "--verbose", action="store_true", default=False, help="Verbose info messages."
)
@@ -203,6 +206,7 @@ def _main(opts):
capture_output=True,
text=True,
shell=True,
check=False,
)

if result.stdout == "0": # Extra checks for non-MNNVL platforms
@@ -323,47 +327,73 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
):
buffer_dtype = torch.uint8
ub_obj = (
tex.CommOverlapP2P(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
opts.comm_type,
set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
atomic_gemm=opts.atomic,
aggregate=opts.aggregate,
use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
(
tex.CommOverlapP2P(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
opts.comm_type,
set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
atomic_gemm=opts.atomic,
aggregate=opts.aggregate,
use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
)
if not opts.use_cublasmp
else tex.CommOverlapP2P(
helper,
tp_rank,
tp_size,
atomic_gemm=opts.atomic,
)
)
if opts.p2p
else tex.CommOverlap(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=opts.atomic,
else (
tex.CommOverlap(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=opts.atomic,
)
if not opts.use_cublasmp
else tex.CommOverlap(
helper,
tp_rank,
tp_size,
atomic_gemm=opts.atomic,
)
)
)

# Numerical check on AG + atomic GEMM requires testing an AG+RS pair
ub_obj2 = None
if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics:
ub_obj2 = (
tex.CommOverlapP2P(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
tex.CommOverlapType.RS,
set_sm_margin=True,
atomic_gemm=True,
(
tex.CommOverlapP2P(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
tex.CommOverlapType.RS,
set_sm_margin=True,
atomic_gemm=True,
)
if not opts.use_cublasmp
else tex.CommOverlapP2P(helper, tp_rank, tp_size, atomic_gemm=True)
)
if opts.atomic_rs_p2p
else tex.CommOverlap(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=True,
else (
tex.CommOverlap(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=True,
)
if not opts.use_cublasmp
else tex.CommOverlap(helper, tp_rank, tp_size, atomic_gemm=True)
)
)
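
The two nested conditional expressions above choose between the userbuffers constructors (which take the buffer shape and dtype) and the new cuBLASMp constructors (which take only the bootstrap helper, TP rank, and TP size). One way to read the first selection (the `ub_obj` assignment) is as a small factory. The sketch below only restates that selection logic; the helper name is illustrative, it is not code from this PR, and it assumes the script's existing `os` and `tex` imports.

```python
# Sketch only: restates the ub_obj backend selection above as a helper function.
# Constructor argument forms are copied from the diff.
def make_ub_obj(opts, helper, tp_rank, tp_size, outer_size, hidden_size, buffer_dtype):
    if opts.use_cublasmp:
        # cuBLASMp variants take no user-provided buffer shape/dtype in this path.
        cls = tex.CommOverlapP2P if opts.p2p else tex.CommOverlap
        return cls(helper, tp_rank, tp_size, atomic_gemm=opts.atomic)
    if opts.p2p:
        # Userbuffers point-to-point overlap: caller supplies the buffer shape/dtype.
        return tex.CommOverlapP2P(
            (outer_size, hidden_size),
            buffer_dtype,
            helper,
            tp_size,
            opts.comm_type,
            set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
            atomic_gemm=opts.atomic,
            aggregate=opts.aggregate,
            use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
        )
    # Userbuffers collective (ring-exchange disabled) overlap.
    return tex.CommOverlap(
        (outer_size, hidden_size),
        buffer_dtype,
        helper,
        tp_size,
        atomic_gemm=opts.atomic,
    )
```
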

@@ -408,7 +438,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
mean=0.0,
std=opts.std,
)
if ub_obj2 is not None:
if opts.comm_type == tex.CommOverlapType.AG and ub_obj2 is not None:
kernel2_t = torch.nn.init.normal_(
torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
7 changes: 7 additions & 0 deletions tests/pytorch/distributed/run_layer_with_overlap.py
@@ -258,6 +258,12 @@ def _parse_args(argv=None, namespace=None):
default=0,
help="Number of layers at the end to run in bf16.",
)
parser.add_argument(
"--use-cublasmp",
action="store_true",
default=False,
help="Use cuBLASMp backend.",
)
args = parser.parse_args(argv, namespace)

if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
@@ -436,6 +442,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
with_cublasmp=opts.use_cublasmp,
)

with te.quantized_model_init(enabled=opts.fp8_init):