diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 4056a5fa7..c0cc26919 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -98,14 +98,6 @@ def setup_pytorch_extension( if version < (12, 0): raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" - mpi_path = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_path / "include") - cxx_flags.append("-DNVTE_UB_WITH_MPI") - library_dirs = [] libraries = [] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): @@ -119,12 +111,22 @@ def setup_pytorch_extension( cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", 0))): - cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM") mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")) include_dirs.append(mpi_home / "include") library_dirs.append(mpi_home / "lib") - libraries.append("mpi_cxx") + libraries.append("mpi") + cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"]) + extra_link_args = [] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" + mpi_path = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")) + include_dirs.append(mpi_path / "include") + library_dirs.append(mpi_path / "lib") + libraries.append("mpi") + cxx_flags.extend(["-DNVTE_UB_WITH_MPI", "-DOMPI_SKIP_MPICXX"]) # Construct PyTorch CUDA extension sources = [str(path) for path in sources] @@ -138,4 +140,5 @@ def setup_pytorch_extension( extra_compile_args={"cxx": cxx_flags}, libraries=[str(lib) for lib in libraries], library_dirs=[str(lib_dir) for lib_dir in library_dirs], + extra_link_args=[str(arg) for arg in extra_link_args], ) diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index e510df176..1fd40305c 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None): ) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument( - "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + "--fp8", action="store_true", default=False, help="Enables the te.autocast() context." 
) parser.add_argument( "--no-comm-overlap", @@ -263,7 +263,13 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) te.module.base.initialize_ub( [batched_size, hidden_size], tp_size, - use_fp8=opts.fp8, + quantization_modes=[ + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) + ], dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ) @@ -293,7 +299,7 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) dist_print(" |-- Forward pass", group=tp_group, debug=True) with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world): y = model(x) if isinstance(y, tuple): out, *_ = y diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py new file mode 100644 index 000000000..ddc848229 --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py @@ -0,0 +1,504 @@ +#!/usr/bin/python3 + +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import sys +import socket +import fcntl +import struct +import argparse +import warnings + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel + +import torch.profiler + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.common.recipe import Format, DelayedScaling + +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser( + description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers." + ) + parser.add_argument( + "-i", "--num-iters", type=int, default=10, help="Number of dummy 'training' iterations." + ) + parser.add_argument("-b", "--batch-size", type=int, default=8, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=16384, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." 
+ ) + parser.add_argument( + "--layer-type", + type=_te_layer_argtype, + default=te.TransformerLayer, + help="Transformer Engine layer to train with comm+GEMM overlap.", + ) + parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + ) + parser.add_argument( + "--no-comm-overlap", + action="store_true", + default=False, + help="Disable the comm+GEMM overlap.", + ) + parser.add_argument( + "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." + ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Print out from every rank instead of just the root rank of relevant process groups.", + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Print out additional debug information.", + ) + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Enable PyTorch profiler.", + ) + parser.add_argument( + "--profile-dir", + type=str, + default="./logs/profiler_traces", + help="Directory to save PyTorch profiler traces.", + ) + parser.add_argument( + "--ub_config", + type=str, + default="./ub_config.json", + help="Userbuffer configuration file.", + ) + + args = parser.parse_args(argv, namespace) + if args.bootstrap_backend == "nccl": + args.bind_to_device = True + return args + + +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not config.no_comm_overlap + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + kwargs["ub_bulk_wgrad"] = not config.no_comm_overlap + kwargs["ub_bulk_dgrad"] = not config.no_comm_overlap + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + # args.append(4 * hidden_size) + args.append(int(3.5 * hidden_size)) + + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap + 
kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + +def create_ub_cfgs(config_file: str, tp_size: int = 8): + import json + with open(config_file, 'r') as f: + data = json.load(f) + cfgs = {} + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None + layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] + layers_all_gather_overlap = [ + "qkv_fprop", + "qkv_dgrad", + "proj_dgrad", + "fc1_fprop", + "fc1_dgrad", + "fc2_dgrad", + ] + + for name, method in data.items(): + if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() + + cfg = { + "method": method, + "is_reduce_scatter": name in layers_reduce_scatter_overlap, + "num_sm": 1 if method in ["ring_exchange", "recursive_doubling"] else 16, + "cga_size": 1 if method in ["ring_exchange", "recursive_doubling"] else 2, + "set_sm_margin": False, + "num_splits": 4 if method == "pipeline" else tp_size, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": name in layers_all_gather_overlap, + "comm_priority": _MAX_STREAM_PRIORITY, + "gemm_priority": _MIN_STREAM_PRIORITY, + } + + cfgs[name] = cfg + + return cfgs + +def _train(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bind_to_device = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + NUM_NODES = WORLD_SIZE // LOCAL_SIZE + + # Initialize torch.distributed global process group and get DP/TP groups + torch.cuda.set_device(LOCAL_RANK) + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init or NUM_NODES > 1: + if NUM_NODES > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." 
+ MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + + def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False): + if debug and not opts.debug: + return + group_rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout + if group_rank == src: + stream.write(f"[rank{WORLD_RANK}] {msg}{end}") + dist.barrier(group) + + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + + # Figure out process groups for tensor- and data-parallelism (if any) + if NUM_NODES > 1: + # Create a list of world ranks on this node + hostname = socket.gethostname() + ifname = os.getenv( + "NVTE_UB_SOCKET_IFNAME", + os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + ) + + if ifname is not None: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + + hostnames = [None for _ in range(WORLD_SIZE)] + dist.all_gather_object(hostnames, hostname) + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + assert len(unique_hosts) == NUM_NODES + + ranks_per_node_list = [[] for _ in range(NUM_NODES)] + self_node_idx = -1 + for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + ranks_per_node_list[node_idx].append(i) + if host == hostname: + self_node_idx = node_idx + assert self_node_idx >= 0 + self_node_ranks = ranks_per_node_list[self_node_idx] + + if opts.num_replicas > 1: + # Split node ranks into multiple replicas + assert len(self_node_ranks) % opts.num_replicas == 0 + tp_size = len(self_node_ranks) // opts.num_replicas + ranks_per_replica_list = [] + for node_ranks in ranks_per_node_list: + for i in range(opts.num_replicas): + start = i * tp_size + end = start + tp_size + ranks_per_replica_list.append(node_ranks[start:end]) + + self_replica_idx = -1 + for i, replica_ranks in enumerate(ranks_per_replica_list): + if WORLD_RANK in replica_ranks: + self_replica_idx = i + break + assert self_replica_idx >= 0 + + else: + # The entire node is the tensor-parallel group + ranks_per_replica_list = ranks_per_node_list + self_replica_idx = self_node_idx + + tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") + ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + + else: + if opts.num_replicas > 1: + # Mixed data- and tensor-parallelism on a single node + # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions + all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") + ranks_per_replica_tensor = all_ranks.reshape( + (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) + ) + tp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.tolist(), backend="nccl" + ) + dp_group, _ = 
dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + else: + dp_group = None + tp_group = nccl_world + + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + dist_print( + f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", + group=tp_group, + ) + if dp_group is not None: + dp_rank = dist.get_rank(dp_group) + dist_print( + f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", + group=dp_group, + ) + else: + dp_rank = 0 + + # Intialize userbuffers + hidden_size = opts.num_heads * opts.head_dim + batched_size = opts.seq_length * opts.batch_size + if not opts.no_comm_overlap: + te.module.base.initialize_ub( + [batched_size, hidden_size], + tp_size, + use_fp8=opts.fp8, + dtype=torch.bfloat16, + bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=create_ub_cfgs(opts.ub_config, tp_size) + ) + # Initialize the fused LayerNorm + Multi-layer Perceptron module + torch.manual_seed(opts.seed + dp_rank) + torch.cuda.manual_seed(opts.seed + tp_rank) + layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size) + model = opts.layer_type(*layer_args, **layer_kwargs) + if dp_group is not None: + model = DistributedDataParallel(model, dim=1, process_group=dp_group) + + # Initialize optimizer with model parameters + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + + # Fp8 recipe setup + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + if opts.profile: + log_dir = os.path.join(opts.profile_dir, f"rank_{WORLD_RANK}") + os.makedirs(log_dir, exist_ok=True) + dist_print(f"Profiler traces will be saved to: {log_dir}", group=nccl_world) + + schedule = torch.profiler.schedule(wait=1, warmup=2, active=5, repeat=1) + + on_trace_ready = torch.profiler.tensorboard_trace_handler( + log_dir, worker_name=f"rank_{WORLD_RANK}" + ) + + profiler_activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + import time + + start_time = time.time() + with torch.profiler.profile( + schedule=schedule, + # record_shapes=True, + # with_stack=True, + # with_flops=True, + # with_modules=True, + on_trace_ready=on_trace_ready, + profile_memory=True, + activities=profiler_activities, + ) as prof: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + prof.step() + torch.cuda.synchronize() + end_time = time.time() + total_wall_clock_time = end_time - start_time + print(f"Total Wall Clock Time: {total_wall_clock_time:.4f} seconds") + # total_flops = sum([item.flops for item in prof.key_averages()]) + # print(f"Total FLOPs: {total_flops}") + else: + dist_print("Starting training iterations...") + for i in 
range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + + dist_print("Finished training!") + te.module.base.destroy_ub() + + dist_print("Destroying all process groups...", debug=True) + dist.destroy_process_group() + if opts.debug and WORLD_RANK == 0: + print("Exiting...\n", end="", flush=True) + + return 0 + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) \ No newline at end of file diff --git a/examples/pytorch/comm_gemm_overlap/ub_config.json b/examples/pytorch/comm_gemm_overlap/ub_config.json new file mode 100644 index 000000000..c6d807f98 --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/ub_config.json @@ -0,0 +1,14 @@ +{ + "qkv_fprop": "recursive_doubling", + "fc1_fprop": "recursive_doubling", + "fc2_dgrad": "recursive_doubling", + + "proj_fprop": "recursive_doubling", + "fc2_fprop": "recursive_doubling", + + "qkv_dgrad": "bulk", + "qkv_wgrad": "bulk", + "fc1_dgrad": "bulk", + "fc1_wgrad": "bulk" + +} \ No newline at end of file diff --git a/hipify_custom_map.json b/hipify_custom_map.json index 8773c233e..5d4467283 100644 --- a/hipify_custom_map.json +++ b/hipify_custom_map.json @@ -5,7 +5,14 @@ "util/cuda_runtime.h" : "util/hip_runtime.h", "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", - "" : "" + "" : "", + "cudaLaunchKernel": "hipLaunchKernel", + "CUmemGenericAllocationHandle": "hipMemGenericAllocationHandle_t", + "\"cuda_runtime.h\"": "\"hip_runtime.h\"", + "cudaLaunchConfig_t": "hipLaunchConfig_t", + "cudaLaunchAttribute": "hipLaunchAttribute", + "cudaLaunchAttributeCooperative": "hipLaunchAttributeCooperative", + "CUdeviceptr": "hipDeviceptr_t" } } diff --git a/setup.py b/setup.py index b28644e03..a3ccb28bf 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,12 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" cmake_flags = [] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" + cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + if rocm_build(): cmake_flags.append("-DUSE_ROCM=ON") if os.getenv("NVTE_AOTRITON_PATH"): @@ -85,11 +91,6 @@ def setup_common_extension() -> CMakeExtension: else: cmake_flags.append("-DUSE_ROCM=OFF") cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" - cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))): assert ( diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 8638c1bce..d2259853d 100644 --- 
a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -1,5 +1,7 @@ #!/usr/bin/python3 +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -28,6 +30,11 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +import transformer_engine.pytorch.cpp_extensions as tex +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + class multi_module_model(torch.nn.Module): def __init__(self, module, num_layers, *args, **kwargs): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index bdbc97517..6d8470a48 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -162,7 +162,11 @@ list(APPEND transformer_engine_SOURCES fused_router/fused_topk_with_score_function.cu recipe/current_scaling.cu recipe/delayed_scaling.cu - recipe/fp8_block_scaling.cu) + recipe/fp8_block_scaling.cu + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.cu + comm_gemm_overlap/comm_gemm_overlap.cpp) if(USE_CUDA) # Removed indent to minimize code diff with NV upstream # Files unique in cuda building @@ -175,11 +179,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp fused_attn/utils.cu - util/cuda_nvml.cpp - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) + util/cuda_nvml.cpp) add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) else() list(APPEND transformer_engine_SOURCES @@ -188,7 +188,8 @@ else() fused_attn_rocm/fused_attn_ck.cpp fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu - amd_detail/system.cpp) + amd_detail/system.cpp + comm_gemm_overlap/rocm_comm_gemm_overlap.cpp) # process source code files set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..) 
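Note on the environment guard added to the run scripts above: the overlap path depends on kernel launch ordering between communication and GEMM kernels (see the _comm_launch_event handling in comm_gemm_overlap.cpp further down), and CUDA Multicast is not available on every device. A minimal standalone sketch of the same setup, using only the calls that appear in this patch:

    import os
    import transformer_engine.pytorch.cpp_extensions as tex

    # Serialize kernel launches onto a single device connection so communication
    # kernels are not reordered relative to the GEMM launches.
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

    # Skip CUDA Multicast on devices that do not support it (e.g. current ROCm GPUs).
    if not tex.device_supports_multicast():
        os.environ["UB_SKIPMC"] = "1"

Both variables need to be in place before any CUDA work is launched, which is why the example and test scripts set them at module import time.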
@@ -235,27 +236,25 @@ target_link_libraries(transformer_engine PUBLIC CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") + target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +endif() # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +# Changed option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) if (NVTE_UB_WITH_MPI) - find_package(MPI REQUIRED) - target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) - target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + # OpenMPI C++ headers are deprecated -- flag unused w/ MPICH + add_definitions(-DOMPI_SKIP_MPICXX) + + target_include_directories(transformer_engine PRIVATE "$ENV{MPI_HOME}/include") + target_link_directories(transformer_engine PRIVATE "$ENV{MPI_HOME}/lib") + target_link_libraries(transformer_engine PUBLIC mpi) target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) endif() - -option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF) -if (NVTE_ENABLE_NVSHMEM) - add_subdirectory(nvshmem_api) - target_link_libraries(transformer_engine PUBLIC nvshmemapi) - target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) -endif() - -# Hack to enable dynamic loading in cuDNN frontend -target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) - + +if (USE_CUDA) + # Hack to enable dynamic loading in cuDNN frontend + target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) else() option(NVTE_ENABLE_ROCSHMEM "Compile with ROCSHMEM library" OFF) if (NVTE_ENABLE_ROCSHMEM) @@ -397,7 +396,7 @@ endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") else() - set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3") + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3 -fopenmp") set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17") # Ask hcc to generate device code during compilation so we can use # host linker to link. 
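Since the MPI branch above resolves headers and libraries through MPI_HOME rather than CMake's find_package(MPI), the environment has to be prepared before the build is invoked. A minimal sketch of a build driven from Python (the pip flags are an assumption; the environment variables are the ones checked in setup.py and build_tools/pytorch.py in this patch):

    import os
    import subprocess
    import sys

    # MPI_HOME must point at an MPI installation with include/ and lib/ subdirectories;
    # setup.py asserts that it is set whenever NVTE_UB_WITH_MPI=1.
    os.environ["NVTE_UB_WITH_MPI"] = "1"
    os.environ.setdefault("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")

    # Build and install Transformer Engine with Userbuffers MPI bootstrapping enabled.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-v", "--no-build-isolation", "."],
        check=True,
    )

OMPI_SKIP_MPICXX only silences the deprecated Open MPI C++ bindings and has no effect on MPICH, so linking plain mpi (instead of mpi_cxx, as build_tools/pytorch.py now does) should work with either implementation.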
diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 40595ea98..3e9fcf1dd 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -21,6 +21,12 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 +#ifdef __HIP_PLATFORM_AMD__ +#define half_dtype hip_bfloat16 +#define __nv_fp8_e5m2 te_hip_fp8_e5m2 +#define __nv_fp8_e4m3 te_hip_fp8_e4m3 +#endif + using namespace std::placeholders; namespace transformer_engine { @@ -64,6 +70,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl #endif _comm_created = true; } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { _use_ce = static_cast(use_ce); _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; @@ -74,7 +89,7 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl _gemm_priority = gemm_priority; _comm_priority = comm_priority; } - for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + for (int i = 0; i < std::max(num_max_streams, num_splits); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); _stream_compute.push_back(std::move(stream)); @@ -101,10 +116,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl DType::kInt32); } // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_comm, 0); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0)); /* Defining the launcher order between the communication and GEMM kernels @@ -114,11 +129,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl */ int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); int runtime_version = 0; - cudaRuntimeGetVersion(&runtime_version); + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version)); cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0)); if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { - cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming)); } else { _comm_launch_event = 0; } @@ -129,18 +144,34 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); - if (_comm_launch_event) cudaEventDestroy(_comm_launch_event); + if (_comm_launch_event) { + cudaEventDestroy(_comm_launch_event); + } + + if (_atomic_gemm) { + cudaFree(_counter.dptr()); + } - if (_atomic_gemm) cudaFree(_counter.dptr()); + for (size_t i = 0; i < _stream_compute.size(); i++) { + cudaStreamSynchronize(_stream_compute[i]); + 
cudaStreamDestroy(_stream_compute[i]); + } - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + auto error = cudaGetLastError(); + if (error != cudaSuccess) { + NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); + } if (_comm_created) { + try { #ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); + destroy_communicator_mpi(_ub_comm); #else - destroy_communicator(_ub_comm); + destroy_communicator(_ub_comm); #endif + } catch (const std::exception &e) { + NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } _comm_created = false; } } @@ -262,6 +293,11 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, @@ -272,7 +308,9 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Register UBuf %d\n", _ub_reg); + } _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); NVTE_CHECK_CUDA( @@ -282,6 +320,7 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType CommOverlapBase::~CommOverlapBase() { cudaEventDestroy(_start_d2dcopy); + cudaStreamSynchronize(_stream_comm); cudaStreamDestroy(_stream_comm); } @@ -295,6 +334,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main) { + printf("bulk_overlap\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -320,7 +360,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, comm_elements, _ub_comm, _stream_comm, - (cudaEvent_t)_comm_launch_event); + (cudaEvent_t)_comm_launch_event); } else { reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); @@ -352,6 +392,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + printf("atomic_gemm_overlap_rs\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -448,6 +489,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + printf("split_overlap_rs\n"); // Get GEMM dimensions int ori_sms = 
_ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -584,6 +626,31 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); } // CommOverlapBase::split_overlap_rs +void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) { + printf("bulk_overlap_external_ag\n"); + + int comm_bytes = _ubuf.bytes(); + int comm_bytes_per_rank = comm_bytes / _tp_size; + + // We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush + userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, send_stream); + userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, recv_stream); + + // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf + for (auto stream : {send_stream, recv_stream}) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0)); + } + + // Next we sync with the main stream + // We have to recapture an event off the comm stream to enable cuda graph capture otherwise the comm stream will be never be joined in the graph + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} + /*************************************************************************************************** * Comm+GEMM Overlap P2P Base (Ring-Exchange) **************************************************************************************************/ @@ -595,40 +662,46 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, CommOverlapType comm_type, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm, bool aggregate) + bool atomic_gemm, bool aggregate, bool use_rd) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate, use_rd); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate, bool use_rd) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; + _use_rd = use_rd; // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / tp_size; - _num_ubuf_chunks = tp_size; + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; if (_is_reduce_scatter) { // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk // outputs for reduction at the end of the pipelining. 
- buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); - _num_ubuf_chunks = tp_size * 2 - 1; + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; } void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg); _ubuf = TensorWrapper( buffer_ptr, - std::vector{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / tp_size, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -651,22 +724,74 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + for (int i = 0; i < _stream_compute.size(); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); } + for (int i = 0; i < 7; i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + l_stream_send.push_back(std::move(stream)); + } + for (int i = 0; i < 7; i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + l_stream_recv.push_back(std::move(stream)); + } NVTE_CHECK_CUDA( cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); + for (int i = 0; i < 7; i++) { + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&l_stop_recv[i], 0)); + } } CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); cudaStreamDestroy(_stream_recv); - for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); + for (size_t i = 0; i < _stream_send.size(); i++) { + cudaStreamDestroy(_stream_send[i]); + } + for (int i = 0; i < 7; i++) { + cudaStreamDestroy(l_stream_recv[i]); + cudaStreamDestroy(l_stream_send[i]); + cudaEventDestroy(l_stop_recv[i]); + } +} + +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? 
source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); } TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, @@ -693,6 +818,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + printf("atomic_gemm_overlap_ag\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -795,6 +921,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + printf("split_overlap_ag\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -809,6 +936,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); @@ -877,12 +1013,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } else { @@ -930,16 +1060,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - 
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + _ub_comm->sms = ori_sms; for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); @@ -959,6 +1089,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + printf("atomic_gemm_overlap_rs\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -1023,6 +1154,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + printf("split_overlap_rs\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; diff --git a/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp new file mode 100644 index 000000000..6c3f02f6f --- /dev/null +++ b/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp @@ -0,0 +1,485 @@ +/************************************************************************* + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "common/util/system.h" +#include "userbuffers/userbuffers.h" +#include + +static int strides[4] = {1,3,5,7}; + +namespace transformer_engine { + +/*void CommOverlapP2PBase::rocm_split_overlap_ag_rd_old(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + printf("rocm_split_overlap_ag_rd\n"); + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? 
A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + const int comm_bytes = _ubufs[0].bytes(); + const bool do_gelu = pre_gelu_out.numel() > 0; + const size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + int steps = 31 - __builtin_clz(_tp_size); + + // Chunk dims + std::vector input_b_chunk_shape = + (transb ? std::vector{k, n_chunk} : std::vector{n_chunk, k}); + std::vector output_chunk_shape = {n_chunk, m}; + size_t input_b_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; + + // GEMM + auto input_b_chunk = + get_buffer_chunk_like(B, input_b_chunk_size * _tp_id, input_b_chunk_shape); + auto output_chunk = + get_tensor_chunk(D, output_chunk_size * _tp_id, output_chunk_shape); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * _tp_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (_tp_id % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[_tp_id % _stream_compute.size()]); + + std::vector owned_chunks; + owned_chunks.reserve(_tp_size); + owned_chunks.push_back(_tp_id); + size_t offset = 1; + + for (int step = 0; step < steps; step++) { + int send_rank = (_tp_id + offset) % _tp_size; + int recv_rank = (_tp_id - offset + _tp_size) % _tp_size; + + for (int i = 0; i < owned_chunks.size(); i++) { + size_t send_offset = owned_chunks[i] * comm_bytes; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, + comm_bytes, _ub_comm, send_rank, _stream_send[i % _stream_send.size()]); + } + + std::vector new_chunks; + for (size_t i = 0; i < owned_chunks.size(); i++) { + size_t new_chunk_id = (recv_rank + i * offset) % _tp_size; + if (new_chunk_id >= _tp_size || + std::find(owned_chunks.begin(), owned_chunks.end(), new_chunk_id) != owned_chunks.end()) continue; + size_t recv_offset = new_chunk_id * comm_bytes; + size_t stream_id = new_chunks.size() % _stream_compute.size(); + + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, + comm_bytes, _ub_comm, recv_rank, _stream_recv); + + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[stream_id], _stop_recv, 0)); + + auto input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * new_chunk_id, input_b_chunk_shape); + output_chunk = get_tensor_chunk(D, output_chunk_size * new_chunk_id, output_chunk_shape); + aux_chunk = (do_gelu) ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * new_chunk_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + workspace_chunk = get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[stream_id]); + + new_chunks.push_back(new_chunk_id); + } + owned_chunks.insert(owned_chunks.end(), new_chunks.begin(), new_chunks.end()); + offset <<= 1; + } + + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // rocm_split_overlap_ag_rd*/ + +void CommOverlapP2PBase::rocm_split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + printf("split_overlap_ag_multi_stride_slice_gemm\n"); + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? 
A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + + const int comm_bytes = _ubufs[0].bytes(); + const int num_rings = 4; + const int strides[4] = {1, 3, 5, 7}; // Coprime strides for TP=8 + + // Each ring handles 1/4 of the data (slice) + NVTE_CHECK(comm_bytes % num_rings == 0, "Comm size must be divisible by num_rings"); + const int slice_bytes = comm_bytes / num_rings; + + // Each slice has 1/4 of the columns + NVTE_CHECK(n_chunk % num_rings == 0, "n_chunk must be divisible by num_rings"); + const size_t n_slice = n_chunk / num_rings; + + const bool do_gelu = pre_gelu_out.numel() > 0; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size()); + } + + // Sync all streams + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (int r = 0; r < num_rings; ++r) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_recv[r], _start_compute, 0)); + } + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + // Events to track when each slice is ready + // slice_ready[chunk_id][ring_id] + const int total_slices = _tp_size * num_rings; + std::vector slice_events(total_slices); + for (int i = 0; i < total_slices; i++) { + // cudaEventDisableTiming is critical for performance here + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&slice_events[i], cudaEventDisableTiming)); + } + + auto get_event = [&](int chunk_id, int ring_id) { + return slice_events[chunk_id * num_rings + ring_id]; + }; + + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaEventRecord(get_event(_tp_id, r), stream_main)); + } + + // Helper: Get byte offset for slice r of chunk_id + auto get_slice_offset = [&](int chunk_id, int ring_id) -> size_t { + return chunk_id * comm_bytes + ring_id * slice_bytes; + }; + + // Slice dimensions for GEMM (operates on 1/4 of chunk) + std::vector input_b_slice_shape = + (transb ? std::vector{k, n_slice} : std::vector{n_slice, k}); + std::vector output_slice_shape = {n_slice, m}; + size_t input_b_slice_elems = n_slice * k; + size_t output_slice_elems = n_slice * m; + + // GEMM launcher for individual SLICE + auto launch_slice_gemm = [&](int chunk_id, int ring_id, int step) { + + // Wait only for THIS slice to arrive + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[ring_id], + get_event(chunk_id, ring_id), 0)); + + // Calculate element offsets for this specific slice + size_t b_elem_offset = (chunk_id * n_chunk * k) + (ring_id * input_b_slice_elems); + size_t d_elem_offset = (chunk_id * n_chunk * m) + (ring_id * output_slice_elems); + + auto input_b_slice = get_buffer_chunk_like(B, b_elem_offset, input_b_slice_shape); + auto output_slice = get_tensor_chunk(D, d_elem_offset, output_slice_shape); + + auto aux_slice = (do_gelu) + ? 
get_tensor_chunk(pre_gelu_out, d_elem_offset, {n_slice, k})
+                       : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype());
+
+    auto workspace_chunk = get_tensor_chunk(workspace, ring_id * workspace_size_chunk,
+                                            {workspace_size_chunk});
+
+    nvte_cublas_gemm(A.data(), input_b_slice.data(), output_slice.data(), bias.data(),
+                     aux_slice.data(), transa, transb, grad, workspace_chunk.data(),
+                     accumulate, use_split_accumulator, _math_sms,
+                     _stream_compute[ring_id]);
+  };
+
+  for (int step = 0; step < _tp_size; step++) {
+    // If step == 0, we just do local GEMM.
+    // If step > 0, we are doing Ring Exchange + GEMM.
+
+    for (int r = 0; r < num_rings; r++) {
+      int stride = strides[r];
+      int next_rank = (_tp_id + stride) % _tp_size;
+      int prev_rank = (_tp_id - stride + _tp_size) % _tp_size;
+
+      // CHUNK CALCULATION
+      int curr_chunk_id = (_tp_size + _tp_id - (step * stride) % _tp_size) % _tp_size;
+
+      // A. LAUNCH GEMM for the chunk we currently have
+      launch_slice_gemm(curr_chunk_id, r, step);
+
+      // B. COMMUNICATE (only if not on the last step)
+      if (step < _tp_size - 1) {
+        int next_recv_chunk_id = (_tp_size + _tp_id - ((step + 1) * stride) % _tp_size) % _tp_size;
+
+        size_t send_off = get_slice_offset(curr_chunk_id, r);
+        size_t recv_off = get_slice_offset(next_recv_chunk_id, r);
+
+        // Wait for GEMM to finish reading if you are using a single buffer?
+        // Or wait for the slice_ready event.
+        NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], get_event(curr_chunk_id, r), 0));
+
+        userbuffers_send(_ub_reg, send_off, _ub_reg, send_off, slice_bytes, _ub_comm, next_rank, l_stream_send[r]);
+        userbuffers_recv(_ub_reg, recv_off, _ub_reg, recv_off, slice_bytes, _ub_comm, prev_rank, l_stream_recv[r]);
+
+        // RECORD for the next step's GEMM
+        NVTE_CHECK_CUDA(cudaEventRecord(get_event(next_recv_chunk_id, r), l_stream_recv[r]));
+      }
+    }
+  }
+
+  // Copy all-gathered B
+  if (B_copy.numel() > 0) {
+    // Wait for all recv streams
+    for (int r = 0; r < num_rings; r++) {
+      // Find the last chunk_id this ring received
+      int last_chunk = (_tp_size + _tp_id - ((_tp_size - 1) * strides[r])) % _tp_size;
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[0], get_event(last_chunk, r), 0));
+    }
+
+    NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(),
+                                    cudaMemcpyDeviceToDevice, l_stream_send[0]));
+  }
+
+  _ub_comm->sms = ori_sms;
+
+  // Final sync
+  for (auto& s : _stream_compute) {
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, s));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
+  }
+
+  for (int r = 0; r < num_rings; r++) {
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, l_stream_send[r]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
+
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, l_stream_recv[r]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+  }
+
+  for (auto& ev : slice_events) {
+    NVTE_CHECK_CUDA(cudaEventDestroy(ev));
+  }
+}  // CommOverlapP2PBase::rocm_split_overlap_ag_smr
+
+}  // namespace transformer_engine
diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
index 65da58d5f..c6b40cd01 100644
--- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
@@ -375,8 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
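The chunk schedule used by the slice/ring overlap above relies on each ring's stride being coprime with the tensor-parallel size, so every ring visits every chunk exactly once. The following standalone sketch is illustrative only and not part of the patch; it assumes tp_size = 8, tp_id = 0, and the strides {1, 3, 5, 7} hard-coded above, and simply reproduces the curr_chunk_id arithmetic to check that property:

#include <array>
#include <cassert>
#include <cstdio>

int main() {
  constexpr int tp_size = 8;
  constexpr int tp_id = 0;                              // this rank's position in the TP group
  constexpr std::array<int, 4> strides = {1, 3, 5, 7};  // coprime with tp_size

  for (size_t r = 0; r < strides.size(); ++r) {
    std::array<bool, tp_size> visited = {};
    for (int step = 0; step < tp_size; ++step) {
      // Same arithmetic as curr_chunk_id in rocm_split_overlap_ag_smr above.
      int chunk = (tp_size + tp_id - (step * strides[r]) % tp_size) % tp_size;
      std::printf("ring %zu, step %d -> chunk %d\n", r, step, chunk);
      visited[chunk] = true;
    }
    for (bool v : visited) assert(v);  // every ring covers all chunks because the stride is coprime with tp_size
  }
  return 0;
}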
cudaMalloc(reinterpret_cast(&(*comm)->flags_baseptr), 2 * GPU_PAGE_SIZE)); NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE)); (*comm)->flags = reinterpret_cast( +#ifdef __HIP_PLATFORM_AMD__ + (reinterpret_cast((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); +#else ((CUdeviceptr)(*comm)->flags_baseptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); - +#endif using namespace std; sched_param param; @@ -511,7 +514,7 @@ void destroy_communicator_mpi(communicator *comm) { } int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { - if (comm->free_region > NVTE_MAX_REGIONS) return -1; + if (comm->free_region >= NVTE_MAX_REGIONS) return -1; int hndl = comm->free_region; comm->peer_ptr[hndl] = reinterpret_cast(malloc(sizeof(void *) * (comm->nvsize))); size_t aligned_size = bytes; @@ -670,9 +673,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * reinterpret_cast(&memhndl), sizeof(cudaIpcMemHandle_t), comm->comm_intra); + // Check for NVLINK support before attempting IPC operations + if (comm->nvsize > 1) { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + cudaDeviceProp deviceProp; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device)); + bool peer_access_available = false; + for (int i = 0; i < comm->nvsize; i++) { + if (i != comm->nvrank) { + int can_access_peer; + cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i); + if (peer_result == cudaSuccess && can_access_peer) { + peer_access_available = true; + break; + } + } + } + if (!peer_access_available) { + free(tmp); + NVTE_ERROR( + "No peer-to-peer access available between GPUs. This platform does not support the " + "GPU-to-GPU " + "communication required for multi-GPU userbuffers. 
Consider using single-GPU mode."); + return 1; + } + } + for (int i = 0; i < comm->nvsize; i++) { if (i != comm->nvrank) { - NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], cudaIpcMemLazyEnablePeerAccess)); } } @@ -693,4 +723,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->mem_ptr[hndl] = *gpubuff; return comm->free_region++; + printf("***** Returning *****\n"); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 1211392e4..adf7bbcbd 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -5,6 +5,15 @@ ************************************************************************/ #include + +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#include "amd_detail/hip_float8.h" +#define half_dtype hip_bfloat16 +#define __nv_fp8_e5m2 te_hip_fp8_e5m2 +#define __nv_fp8_e4m3 te_hip_fp8_e4m3 +#else #include #include @@ -13,6 +22,7 @@ #else #define half_dtype half #endif +#endif #include #include @@ -24,6 +34,7 @@ #define MAX_THREADS 1024 +#if !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) #define ATOMIC_CONSUMER(chunk) \ if (counters) { \ if (threadIdx.x == 0 && blockIdx.x == 0) { \ @@ -34,6 +45,18 @@ } \ if (blockIdx.x == 0) __syncthreads(); \ } +#else +#define ATOMIC_CONSUMER(chunk) \ + if (counters) { \ + if (threadIdx.x == 0 && blockIdx.x == 0) { \ + while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ + } \ + ((unsigned int *)counters)[chunk] = 1; \ + __threadfence_system(); \ + } \ + if (blockIdx.x == 0) __syncthreads(); \ + } +#endif #define ATOMIC_PRODUCER(chunk) \ if (counters) { \ @@ -1025,7 +1048,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // reset counter for next producer. ((unsigned int *)counters)[0] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence_system(); +#endif } } __syncthreads(); @@ -1116,7 +1143,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // reset counter for next producer. ((unsigned int *)counters)[chunk_i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence_system(); +#endif } } __syncthreads(); @@ -1357,6 +1388,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } } // fp16 inplace allgather kernel (Volta,Hopper) +#ifndef __HIP_PLATFORM_AMD__ #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ cudaLaunchAttribute attribute_ub[2]; \ @@ -1367,6 +1399,15 @@ __global__ void __launch_bounds__(MAX_THREADS) attribute_ub[0].id = cudaLaunchAttributeCooperative; \ cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 
2 : 1; +#else +#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[1]; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + attribute_ub[0].value.cooperative = 1; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = 1; +#endif #if (CUDART_VERSION >= 12030) #define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ @@ -1378,6 +1419,11 @@ __global__ void __launch_bounds__(MAX_THREADS) #define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2 #endif +#ifdef __HIP_PLATFORM_AMD__ +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + cudaLaunchConfig_t cfg; \ + NVTE_ERROR("SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT is not supported for AMD GPUs") +#else #define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \ @@ -1389,6 +1435,7 @@ __global__ void __launch_bounds__(MAX_THREADS) attribute_ub[0].id = cudaLaunchAttributeCooperative; \ cfg.attrs = attribute_ub; \ cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH; +#endif #define callranks_ag(x) \ if (ar_nvsize == x) { \ @@ -2196,7 +2243,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // Decrement atomic val to signal current output tile finish if (counters) { ((unsigned int *)counters)[0] = 0; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence_system(); +#endif } } } @@ -2267,7 +2318,11 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat // Decrement atomic val to signal current output tile finish if (counters) { ((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence_system(); +#endif } } @@ -2319,6 +2374,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds if (comm->push == 0) { kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), reinterpret_cast(flagptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + srcoffset; void *dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset; @@ -2516,8 +2572,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), reinterpret_cast(srcptr), reinterpret_cast(dstptr), signalonly ? 0 : bytes / 16, comm->ub_timeout); - if (!signalonly) + NVTE_CHECK_CUDA(cudaGetLastError()); + if (!signalonly) { kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); + NVTE_CHECK_CUDA(cudaGetLastError()); + } if (comm->use_ce) { NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); } @@ -2532,6 +2591,33 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds reinterpret_cast(0 ? 
// temporary disable GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) : nullptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} + +void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; + for (int j = 1; j < tp_size; j++) { + int i = (tp_rank + j) % tp_size; + int send_offset = srcoffset + bytes_per_slice * tp_rank; + int recv_offset = dstoffset + bytes_per_slice * tp_rank; + userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); + } +} + +void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; + for (int j = tp_size - 1; j > 0; j--) { + int i = (tp_rank + j) % tp_size; + int send_offset = srcoffset + bytes_per_slice * i; + int recv_offset = dstoffset + bytes_per_slice * i; + userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); } } @@ -2545,7 +2631,11 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { // COMM kernel need to explicitely flash gmem. // GEMM kernel already executed, and can not see gmem // change without COMM kernel explicitely make change +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence_system(); +#endif } // consumer @@ -2555,7 +2645,11 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) { } ((unsigned int *)atomic_ptr)[chunk_i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence_system(); +#endif } } @@ -2567,7 +2661,11 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) { } ((unsigned int *)atomic_ptr)[i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence_system(); +#endif } } } @@ -2588,24 +2686,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); producer_kernel<<>>(atomic_ptr, chunk_i); + NVTE_CHECK_CUDA(cudaGetLastError()); } void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); consumer_kernel<<>>(atomic_ptr, chunk_i); + NVTE_CHECK_CUDA(cudaGetLastError()); } void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) { dim3 block(1); dim3 grid(1); consumer_batch_kernel<<>>(atomic_ptr, first_chunk_i, num_chunks); + NVTE_CHECK_CUDA(cudaGetLastError()); } void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) { dim3 block(1); dim3 grid(1); reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -2659,6 +2761,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in reduce_fp8_in_bf16_out_cuda <<>>(inputs, output, scale, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); + NVTE_CHECK_CUDA(cudaGetLastError()); } template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void 
*output, float *scale, @@ -2714,4 +2817,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud dim3 grid(num_blocks); reduce_bf16_cuda<<>>( inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); + NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 03e45b978..4d52fbb64 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -27,7 +27,7 @@ using ExtAllgatherOp = std::function; using ExtBarrierOp = std::function; -#define NVTE_MAX_REGIONS 16 +#define NVTE_MAX_REGIONS 32 #define NVTE_MAX_SMS 32 #define NVTE_MAX_OPS 32 #define NVTE_MAX_PEERS 8192 @@ -304,4 +304,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream); +void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); + +void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); + #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 293c57526..7d5dab28a 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
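As a usage sketch for the userbuffers_send_all/userbuffers_recv_all declarations added above: pairing the two calls on the same registered region amounts to an in-place all-gather of per-rank slices across the tensor-parallel group. The snippet below is illustrative only; the wrapper name and the zero offsets are assumptions, not part of the patch, and the arguments are the same handles and communicator that userbuffers_send()/userbuffers_recv() already expect.

#include "userbuffers.h"  // declarations above; include path assumed relative to this directory

// Hypothetical convenience wrapper: all-gather tp_size slices of slice_bytes each,
// held back-to-back in the registered region `reg`, across the local TP group.
static void allgather_region(int reg, size_t slice_bytes, int tp_rank, int tp_size,
                             int world_rank, communicator *comm, cudaStream_t stream) {
  // Push this rank's slice (at offset slice_bytes * tp_rank) to every peer...
  userbuffers_send_all(reg, /*srcoffset=*/0, reg, /*dstoffset=*/0, slice_bytes,
                       tp_rank, tp_size, world_rank, comm, stream);
  // ...and post the matching receives so every peer's slice lands at its own offset.
  userbuffers_recv_all(reg, /*srcoffset=*/0, reg, /*dstoffset=*/0, slice_bytes,
                       tp_rank, tp_size, world_rank, comm, stream);
}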
@@ -15,7 +17,7 @@ #include "common/comm_gemm_overlap/userbuffers/userbuffers.h" -#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +#define NVTE_COMM_OVERLAP_MAX_STREAMS 7 namespace transformer_engine { @@ -36,7 +38,9 @@ enum class CommOverlapAlgo { SPLIT_PIPELINED_RS_P2P = 4, ATOMIC_GEMM_RS = 5, ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 + ATOMIC_GEMM_RS_P2P = 7, + EXTERNAL_BULK_OVERLAP_AG = 8, + SPLIT_PIPELINED_AG_RD_P2P = 9 }; class CommOverlapCore { @@ -57,6 +61,7 @@ class CommOverlapCore { int _comm_priority; bool _atomic_gemm{false}; bool _is_p2p{false}; + bool _use_rd{false}; TensorWrapper _ubuf; TensorWrapper _counter; @@ -66,6 +71,11 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + public: CommOverlapCore() {} // dummy constructor for exposing type to Python @@ -77,23 +87,38 @@ class CommOverlapCore { virtual ~CommOverlapCore(); + void *get_ubuf_dptr() { return _ubuf.dptr(); } + void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; _ubuf_scale_inv_initialized = true; } + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, const std::vector &shape); TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + int get_tp_size() { return _tp_size; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } + bool is_use_rd() { return _use_rd; } + bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + virtual bool is_aggregate() { + NVTE_ERROR("Operation is not implemented."); + } + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, @@ -133,6 +158,19 @@ class CommOverlapCore { cudaStream_t stream_main) { NVTE_ERROR("Operation is not implemented."); } + + virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void rocm_split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { @@ -142,6 +180,10 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + public: CommOverlapBase() {} // dummy constructor for exposing type to Python @@ -198,6 +240,22 @@ class CommOverlapBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + + void bulk_overlap_external_ag(cudaStream_t send_stream, 
cudaStream_t recv_stream, + cudaStream_t stream_main) override; + + /* + ** Split AllGather + GEMM using P2P communication using recursive doubling + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. + */ + void rocm_split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { @@ -211,9 +269,13 @@ class CommOverlapP2PBase : public CommOverlapCore { int _num_ubuf_chunks; int _self_chunk_id; std::vector _ubufs; - std::vector _stream_send; + std::vector _stream_send, l_stream_send, l_stream_recv; cudaStream_t _stream_recv; - cudaEvent_t _stop_send, _stop_recv; + cudaEvent_t _stop_send, _stop_recv, l_stop_recv[7]; + + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate, bool use_rd); public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python @@ -224,10 +286,13 @@ class CommOverlapP2PBase : public CommOverlapCore { CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, - bool atomic_gemm = false, bool aggregate = false); + bool atomic_gemm = false, bool aggregate = false, bool use_rd = false); virtual ~CommOverlapP2PBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, @@ -277,6 +342,26 @@ class CommOverlapP2PBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + + /* + ** Split AllGather + GEMM using P2P communication using recursive doubling + */ + void rocm_split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; + + bool is_aggregate() { return _aggregate; } // needed for rocm pathing + + /* + ** This function overlaps the AG for the current communicator object with the GEMM for the overlap_gemm object. + ** The gemm for overlap_gemm is assumed to have been previously started. 
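+  ** (Note: in this CommOverlapP2PBase specialization the override below only raises
+  ** NVTE_ERROR("Operation not supported."); the supported bulk_overlap_external_ag path
+  ** is the CommOverlapBase override declared earlier in this header.)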
+ */ + void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/multi_stream.h b/transformer_engine/common/include/transformer_engine/multi_stream.h index e406a0786..cf67711f1 100644 --- a/transformer_engine/common/include/transformer_engine/multi_stream.h +++ b/transformer_engine/common/include/transformer_engine/multi_stream.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -11,7 +13,11 @@ #ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H #define TRANSFORMER_ENGINE_MULTI_STREAM_H +#ifdef __HIP_PLATFORM_AMD__ +#include "util/hip_runtime.h" +#else #include "cuda_runtime.h" +#endif #ifdef __cplusplus extern "C" { diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 896f09e50..b813b2549 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -27,7 +27,7 @@ namespace { #include "string_path_cuda_include.h" } // namespace -#endif // __HIP_PLATFORM_AMD__ +#endif // #ifndef __HIP_PLATFORM_AMD__ int num_devices() { auto query_num_devices = []() -> int { @@ -103,7 +103,6 @@ int sm_count(int device_id) { return cache[device_id]; } -#ifndef __HIP_PLATFORM_AMD__ void stream_priority_range(int *low_priority, int *high_priority, int device_id) { static std::vector> cache(num_devices()); static std::vector flags(num_devices()); @@ -124,6 +123,11 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id) *high_priority = cache[device_id].second; } +#ifdef __HIP_PLATFORM_AMD__ +bool supports_multicast(int _) { + return false; +} +#else bool supports_multicast(int device_id) { #if CUDART_VERSION >= 12010 // NOTE: This needs to be guarded at compile-time and run-time because the diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 58712c9d9..c0d46992a 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -50,7 +50,6 @@ const std::string &sm_arch_name(int device_id = -1); */ int sm_count(int device_id = -1); -#ifndef __HIP_PLATFORM_AMD__ /* \brief Minimum and maximum stream priorities supported on device * * \param[in] device_id CUDA device (default is current device) @@ -69,6 +68,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id */ bool supports_multicast(int device_id = -1); +#ifndef __HIP_PLATFORM_AMD__ /* \brief Path to CUDA Toolkit headers * * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index be06e807e..b4c1eaad0 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -10,10 +10,7 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include -//TODO: rocm does not support comm gemm overlap yet -#ifndef USE_ROCM #include -#endif #include #include @@ -35,9 +32,6 @@ .value("NVTE_No_Backend", 
NVTE_Fused_Attn_Backend::NVTE_No_Backend); #endif -// Define comm overlap handles if not using ROCm -#ifndef USE_ROCM - #define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ @@ -54,7 +48,11 @@ transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ + .value("EXTERNAL_BULK_OVERLAP_AG", \ + transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG) \ + .value("SPLIT_PIPELINED_AG_RD_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_RD_P2P); \ py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ @@ -89,14 +87,6 @@ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ py::call_guard()); -#else -#define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ - pybind11::class_(m, "CommOverlapType", \ - pybind11::module_local()); \ - py::class_>(m, "CommOverlapCore", \ - pybind11::module_local()); -#endif #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ pybind11::enum_(m, "DType", pybind11::module_local()) \ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 29875584b..c1646793a 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -30,9 +30,7 @@ #include #include #include -#ifndef USE_ROCM #include -#endif #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 72151f41a..24c7151d9 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -13,14 +13,6 @@ #include "common.h" -#ifdef USE_ROCM -namespace transformer_engine { -//dummy CommOverlapCore, CommOverlapType in rocm -class CommOverlapCore{}; -class CommOverlapType{}; -} -#endif - namespace transformer_engine::pytorch { /*************************************************************************************************** @@ -449,7 +441,6 @@ void rocshmem_finalize(); } // namespace transformer_engine::pytorch -#ifndef USE_ROCM /*************************************************************************************************** * Comm+GEMM Overlap Wrappers **************************************************************************************************/ @@ -509,7 +500,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, - bool aggregate = false); + bool aggregate = false, bool use_rd = false); ~CommOverlapP2P() {} @@ -521,6 +512,5 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm at::Stream get_communication_stream(); }; // CommOverlapP2P -#endif // !USE_ROCM #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 737c1d707..77b683d10 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -5,7 +5,6 @@ * * See LICENSE for license information. ************************************************************************/ -#ifndef USE_ROCM #include "../extensions.h" #include "transformer_engine/transformer_engine.h" @@ -231,14 +230,14 @@ CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::Scal te::CommOverlapType comm_type, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, - bool aggregate) + bool aggregate, bool use_rd) : te::CommOverlapP2PBase( buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, - atomic_gemm, aggregate) {} + atomic_gemm, aggregate, use_rd) {} /* ** Copy input to _ubufs[0] @@ -310,4 +309,3 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional gemm(py::handle A, bool transa, py::handle B, bool trans std::move(swizzle_scaling_factors(B_tensor, !transb))); if (comm_overlap) { -#ifndef USE_ROCM // Prepare extra output tensor TensorWrapper extra_output_tensor; if (extra_output.has_value()) { @@ -196,7 +195,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans extra_output_tensor = makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); } - // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ @@ -213,6 +211,13 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); + } else if (comm_overlap->is_use_rd() && !comm_overlap->is_aggregate()) { + NVTE_SCOPED_GIL_RELEASE({ + comm_overlap->rocm_split_overlap_ag_rd(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + }); } else { NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, @@ -238,9 +243,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans }); } } -#else - NVTE_ERROR("ROCm TE does not support comm_overlap\n"); -#endif //!USE_ROCM } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 14f5c83a4..512ac7b12 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -422,7 +422,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3) .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3); -#ifndef USE_ROCM py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) .def(py::init>(), @@ -450,21 +449,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m, "CommOverlapP2P") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, - bool>(), + bool, bool>(), py::call_guard(), 
py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, - py::arg("use_ce") = true, py::arg("aggregate") = false) + py::arg("use_ce") = true, py::arg("aggregate") = false, py::arg("use_rd") = false) .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); -#else - m.def("CommOverlapHelper", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); - m.def("CommOverlap", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); - m.def("CommOverlapP2P", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); -#endif //USE_ROCM } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a6ab1b22a..d824cc765 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -15,6 +15,7 @@ from contextlib import contextmanager import logging from types import MethodType +from itertools import chain import torch import torch.nn.functional as F @@ -276,6 +277,7 @@ def initialize_ub( "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "pipeline": ["proj_fprop", "fc2_fprop"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + "recursive_doubling": [], } # AG-RS overlap pairs of layers forming a tensor-parallel block @@ -310,6 +312,7 @@ def get_default_config(name): "comm_priority": _MAX_STREAM_PRIORITY, "gemm_priority": _MIN_STREAM_PRIORITY, "pipeline_rs_overlap_first_gemm": False, + "use_rd": False, } return default_cfg @@ -328,6 +331,7 @@ def add_ub( comm_priority: int = 0, gemm_priority: int = 0, pipeline_rs_overlap_first_gemm: bool = False, + use_rd: bool = False, ) -> None: if atomic_gemm: warnings.warn( @@ -364,7 +368,8 @@ def add_ub( assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype - if method == "ring_exchange": + use_rd = method == "recursive_doubling" + if method == "ring_exchange" or use_rd: ub_obj = tex.CommOverlapP2P( shape, # Communication buffer shape buffer_dtype, # Communication buffer data type @@ -380,6 +385,7 @@ def add_ub( aggregate=aggregate, gemm_priority=gemm_priority, comm_priority=comm_priority, + use_rd=use_rd, ) else: ub_obj = tex.CommOverlap( @@ -411,7 +417,7 @@ def add_ub( new_method = ub_cfgs[name]["method"] methods[new_method].append(name) - for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: + for name in chain.from_iterable(methods.values()): ub_cfg = get_default_config(name) if ub_cfgs is not None and name in ub_cfgs: fp8_buf = (name in layers_all_gather_overlap) or (