From 457f45cd4f6887bcb0f226dc3c4f29c3d849ae7a Mon Sep 17 00:00:00 2001
From: Sheng Fu
Date: Tue, 20 May 2025 15:20:51 -0700
Subject: [PATCH] Shengf/fix all2all (#206)

Summary: Fix support for replaying all2all.

Test Plan: Constructed a 4-rank case that invokes torch.distributed.all_to_all()
and torch.distributed.all_to_all_single(), then dumped the trace and replayed it.

Differential Revision: D75101007

Pulled By: shengfukevin
---
 .gitignore                                 |   1 +
 et_replay/comm/backend/base_backend.py     |   1 -
 .../comm/backend/pytorch_dist_backend.py   |  20 +---
 et_replay/comm/comms_utils.py              |  60 +++-------
 et_replay/comm/profiler_trace_analysis.py  | 109 ++++++++++++++++--
 et_replay/pyproject.toml                   |   1 +
 6 files changed, 115 insertions(+), 77 deletions(-)

diff --git a/.gitignore b/.gitignore
index a230a78a..93163afa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 .venv/
+.vscode/
 __pycache__/
diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py
index 9fb708f6..a81b0096 100644
--- a/et_replay/comm/backend/base_backend.py
+++ b/et_replay/comm/backend/base_backend.py
@@ -127,7 +127,6 @@ class BaseBackend(ABC):
     def __init__(self) -> None:
         self.tcp_store = None
         self.collectiveFunc = {
-            "all_to_all_single": self.all_to_all_single,  # pyre-ignore[16]:
             "all_to_all": self.all_to_all,
             "all_to_allv": self.all_to_allv,
             "all_reduce": self.all_reduce,
diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index 88606aa7..0b000ddc 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -241,6 +241,7 @@ def all_to_all(
         return work
 
     def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
+        # the cpp layer all_to_allv corresponds to the python layer all_to_all_single
         # pair=True mode does not support quantization
         if (
             collectiveArgs.all2all_qcomm
@@ -301,25 +302,6 @@ def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
         if retFlag:
             return work
 
-    def all_to_all_single(self, collectiveArgs, retFlag=False, pair=False):
-        # does not support quantization
-        if collectiveArgs.all2all_qcomm:
-            logger.warn("all_to_all_single does not support quantization")
-            return
-
-        work = dist.all_to_all_single(
-            collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
-            collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
-            group=collectiveArgs.group,
-            async_op=collectiveArgs.asyncOp,
-        )
-
-        if collectiveArgs.asyncOp:
-            collectiveArgs.waitObj.append(work)
-
-        if retFlag:
-            return work
-
     def all_gather(self, collectiveArgs, retFlag=False, pair=False):
         if self.use_ext_dist:
             retObj = collectiveArgs.group.all_gather(
diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index d67c0597..593e23bb 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -114,7 +114,6 @@ def fixBeginSize(commsParams: commsParamsHolder, world_size: int) -> None:
     if commsParams.collective in (
         "all_to_all",
         "all_to_allv",
-        "all_to_all_single",
         "all_gather",
         "all_gather_base",
         "gather",
@@ -300,14 +299,13 @@ def checkQuantArgs(
     if collective not in (
         "all_to_all",
         "all_to_allv",
-        "all_to_all_single",
         "reduce",
         "all_reduce",
     ):
         raise NotImplementedError(
             f"quantized communication for {collective} is currently unsupported."
         )
-    if collective in ("all_to_all", "all_to_allv", "all_to_all_single"):
+    if collective in ("all_to_all", "all_to_allv"):
         if (beginSize // 4) % quant_a2a_embedding_dim != 0:
             logger.warning(
                 f"begin size {beginSize} must be a multiple of --quant-a2a-embedding-dim {quant_a2a_embedding_dim} for all_to_all operation"
@@ -349,7 +347,6 @@ def paramToCommName(name: str, supported_comms: list[str] | None = None) -> str:
         "alltoall": "all_to_all",
         "alltoallv": "all_to_allv",
         "alltoallbase": "all_to_allv",
-        "alltoallsingle": "all_to_all_single",
         "allreduce": "all_reduce",
         "allgather": "all_gather",
         "allgatherbase": "all_gather_base",
@@ -880,58 +877,28 @@ def _prep_all_to_allv(
         opTensor = torch.Tensor()
         if allocate:
             # all_to_allv requires two tensors
+            # ipTensor has been allocated outside of this function, just pass it in
             opTensor = self.backendFuncs.alloc_random(
                 [numElementsOut], curDevice, dtype, scaleFactor
             )
         # recorded splits in trace is only for dim 0, but tensor in replay has been flattened.
         # need to recalculate the splits for flattened 1D tensor
+        # corner case: one rank sends zero data out, but receives data from other ranks, and vice versa.
         self.collectiveArgs.opTensor_split = (
-            [numElementsOut // sum(curComm.outSplit) * i for i in curComm.outSplit]
+            [
+                numElementsOut // max(sum(curComm.outSplit), 1) * i
+                for i in curComm.outSplit
+            ]
             if curComm.outSplit
             else None
         )
         self.collectiveArgs.ipTensor_split = (
-            [numElementsIn // sum(curComm.inSplit) * i for i in curComm.inSplit]
+            [numElementsIn // max(sum(curComm.inSplit), 1) * i for i in curComm.inSplit]
             if curComm.inSplit
             else None
        )
         return (ipTensor, opTensor)
 
-    def _prep_all_to_all_single(
-        self,
-        ipTensor: torch.Tensor,
-        curComm: commsArgs,
-        commsParams: commsParamsHolderBase,
-        numElementsIn: int,
-        numElementsOut: int,
-        world_size: int,
-        curDevice: str,
-        dtype: torch.dtype,
-        scaleFactor: float,
-        allocate: bool = True,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        ipTensor = torch.Tensor()
-        opTensor = torch.Tensor()
-        if allocate:
-            if commsParams.dcheck == 1:
-                ipTensor = self.backendFuncs.alloc_ones(
-                    [numElementsIn],
-                    curDevice,
-                    commsParams.dtype,
-                    self.initVal,
-                )
-            else:
-                ipTensor = self.backendFuncs.alloc_random(
-                    [numElementsIn],
-                    curDevice,
-                    commsParams.dtype,
-                    scaleFactor,
-                )
-            opTensor = self.backendFuncs.alloc_random(
-                [numElementsOut], curDevice, dtype, scaleFactor
-            )
-        return (ipTensor, opTensor)
-
     def _prep_all_to_all(
         self,
         ipTensor: list[torch.Tensor],
@@ -948,17 +915,21 @@ def _prep_all_to_all(
         ipTensor = []
         opTensor = []
         if allocate:
-            alloc_func = (
+            i_alloc_func = (
                 self.backendFuncs.alloc_ones
                 if commsParams.dcheck == 1
                 else self.backendFuncs.alloc_random
             )
+            i_scale_factor = self.initVal if commsParams.dcheck == 1 else scaleFactor
             ipTensor = [
-                alloc_func(i, curDevice, commsParams.dtype, self.initVal)
+                i_alloc_func([i], curDevice, commsParams.dtype, i_scale_factor)
                 for i in curComm.inSplit
             ]
+
             opTensor = [
-                alloc_func(i, curDevice, commsParams.dtype, self.initVal)
+                self.backendFuncs.alloc_random(
+                    [i], curDevice, commsParams.dtype, scaleFactor
+                )
                 for i in curComm.outSplit
             ]
         return (ipTensor, opTensor)
@@ -1247,7 +1218,6 @@ def prepComm(
 
         # TODO: consider using this dictionary to check valid keywords rather than silently defaulting
         dispatchDict = {
-            "all_to_all_single": self._prep_all_to_all_single,
             "all_to_allv": self._prep_all_to_allv,
             "all_to_all": self._prep_all_to_all,
             "all_gather": self._prep_all_gather,
diff --git a/et_replay/comm/profiler_trace_analysis.py b/et_replay/comm/profiler_trace_analysis.py
index dd5170d2..69d4a5c0 100644
--- a/et_replay/comm/profiler_trace_analysis.py
+++ b/et_replay/comm/profiler_trace_analysis.py
@@ -1,8 +1,11 @@
 import ast
+import functools
 import json
 import logging
 import os
 import pathlib
+import re
+import time
 
 from collections import defaultdict
 from typing import Any, Callable, Dict
@@ -12,6 +15,21 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
+
+def timer_decorator(func):
+    """Decorator that prints the execution time of a function"""
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
+        return result
+
+    return wrapper
+
+
 # refer to:
 # https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/c10/core/ScalarType.h#L61
 _dtype_size_map: Dict[str, int] = {
@@ -139,7 +157,54 @@ def _get_event_busbw_factor(evt):
     return correction_factor_func(group_size)
 
 
-def calculate_bw_(trace_data):
+def _is_uneven_all_to_all_evt(evt):
+    coll_name = _get_dict_value(
+        evt["args"], "Collective name", f'Missing "Collective name" in event: {evt}'
+    )
+    return coll_name in ["all_to_all", "all_to_allv"] and (
+        ast.literal_eval(evt["args"]["In split size"])
+        or ast.literal_eval(evt["args"]["Out split size"])
+    )
+
+
+def _get_uneven_all_to_all_data_size(evt, global_rank):
+    group_size = evt["args"]["Group size"]
+    local_rank = _parse_ranks(evt["args"]["Process Group Ranks"], group_size).index(
+        global_rank
+    )
+    in_elems_count = evt["args"]["In msg nelems"]
+    out_elems_count = evt["args"]["Out msg nelems"]
+    in_split_size = ast.literal_eval(evt["args"]["In split size"])
+    out_split_size = ast.literal_eval(evt["args"]["Out split size"])
+    dtype_size = _dtype_size_map[evt["args"]["dtype"]]
+
+    if (in_split_size and in_split_size[-1] == Ellipsis) or (
+        out_split_size and out_split_size[-1] == Ellipsis
+    ):
+        in_split_size = []
+        out_split_size = []
+        logger.warning(f"Fall back to even all2all bw calculation for event: {evt}")
+
+    if in_split_size:
+        send_elems = in_elems_count - in_split_size[local_rank]
+    else:
+        send_elems = in_elems_count / group_size * (group_size - 1)
+
+    if out_split_size:
+        recv_elems = out_elems_count - out_split_size[local_rank]
+    else:
+        recv_elems = out_elems_count / group_size * (group_size - 1)
+
+    return max(send_elems, recv_elems) * dtype_size
+
+
+def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
+    return round(
+        _get_uneven_all_to_all_data_size(evt, global_rank) / evt["dur"] * 1e-3, 2
+    )
+
+
+def calculate_bw_(trace_data, global_rank):
     nccl_events = [
         i
         for i in trace_data["traceEvents"]
@@ -163,7 +228,11 @@
 
         algbw = _calculate_algbw(evt)
         busbw_factor = _get_event_busbw_factor(evt)
-        busbw = round(algbw * busbw_factor, 2)
+        if _is_uneven_all_to_all_evt(evt):
+            # calculate busbw for uneven all_to_all
+            busbw = _calculate_busbw_for_uneven_all_to_all(evt, global_rank)
+        else:
+            busbw = round(algbw * busbw_factor, 2)
 
         evt["args"]["algbw (GB/sec)"] = algbw
         evt["args"]["busbw (GB/sec)"] = busbw
@@ -178,7 +247,7 @@
             logger.error(f"- Error: {err_msg}")
 
 
-def calculate_sbw(trace_data):
+def calculate_sbw(trace_data, global_rank):
     # calculate shared bw per rank
     nccl_events = [
         i
@@ -193,6 +262,8 @@
 
     total_data_size = sum(
         _calculate_event_data_size(evt) * _get_event_busbw_factor(evt)
+        if not _is_uneven_all_to_all_evt(evt)
+        else _get_uneven_all_to_all_data_size(evt, global_rank)
         for evt in nccl_events
     )
 
@@ -232,6 +303,13 @@ def pick_comm_bw_(trace_data, comm_bw_data):
     rank = trace_data["distributedInfo"]["rank"]
+
+    group_ranks_to_pg_id = defaultdict(list)
+    for pg in trace_data["distributedInfo"]["pg_config"]:
+        group_ranks_to_pg_id[tuple(pg["ranks"])].append(int(pg["pg_name"]))
+    for ranks in group_ranks_to_pg_id:
+        group_ranks_to_pg_id[ranks].sort()
+
     nccl_events = [
         i
         for i in trace_data["traceEvents"]
@@ -239,18 +317,22 @@
         and i["name"].startswith(("ncclDevKernel_", "ncclKernel_"))
         and "algbw (GB/sec)" in i["args"]
     ]
+    pg_name2config = {
+        pg["pg_name"]: pg for pg in trace_data["distributedInfo"]["pg_config"]
+    }
 
     for evt in nccl_events:
         knl_name = evt["name"][: evt["name"].index("(")]
         coll_name = evt["args"]["Collective name"]
         data_size = _calculate_event_data_size(evt)
-        ranks_count = evt["args"]["Group size"]
-        ranks = _parse_ranks(evt["args"]["Process Group Ranks"], ranks_count)
 
+        ranks_count = evt["args"]["Group size"]
         pg_id = int(evt["args"]["Process Group Name"])
-        pg = (*ranks, pg_id) if ranks and rank == min(ranks) else None
+        ranks = pg_name2config[evt["args"]["Process Group Name"]]["ranks"]
+
+        # If there are multiple process groups with the same ranks, the last element
+        # of this tuple is an index that is identical across ranks and differentiates them.
+        pg = (*ranks, group_ranks_to_pg_id[tuple(ranks)].index(pg_id))
 
-        # TODO: calculation of unbalanced all2all bw needs to be improved
-        # all2all is implemented by single ncclDevKernel_SendRecv() in NCCL
         comm_bw_data[(knl_name, coll_name, data_size, ranks_count)].append(
             [
                 evt["dur"],
@@ -261,6 +343,7 @@
         )
 
 
+@timer_decorator
 def analyze_profiler_trace(trace_dir: str, report_dir: str):
     """
     Analyse input PyTorch profiler trace (i.e. Kineto trace) and generate report.
@@ -282,7 +365,7 @@
 
     # list of shared bw
     sbw_lst = []
-    # key is (kernel_name, data size, ranks number)
+    # key is (kernel_name, coll name, data size, ranks count)
     # value is list of [dur, algbw, busbw, pg]
     comm_bw_data = defaultdict(list)
 
@@ -293,13 +376,15 @@
         with open(fpath.path, "r", encoding="utf-8") as f:
             trace = json.load(f)
 
-        calculate_bw_(trace)
+        global_rank = trace["distributedInfo"]["rank"]
+        calculate_bw_(trace, global_rank)
+
         with open(
             os.path.join(processed_trace_dir, fpath.name), "w", encoding="utf-8"
         ) as f:
             json.dump(trace, f)
 
-        sbw_lst.append(calculate_sbw(trace))
+        sbw_lst.append(calculate_sbw(trace, global_rank))
 
         pick_iter_e2e_time_(trace, iter_e2e_time)
         pick_comm_bw_(trace, comm_bw_data)
@@ -330,7 +415,7 @@
             f"avg. E2ETime of iters among all ranks: {sum(iter_e2e_time) / len(iter_e2e_time) / 1e3 :.3f} ms\n"
         )
         f.write(
-            f"avg. SharedBW (i.e. sum(data_size * busbw_factor) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
+            f"avg. SharedBW (i.e. sum(busbw_data_size) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
         )
         f.write(
diff --git a/et_replay/pyproject.toml b/et_replay/pyproject.toml
index 19dbf0ea..00811ee1 100644
--- a/et_replay/pyproject.toml
+++ b/et_replay/pyproject.toml
@@ -8,6 +8,7 @@ version = "0.5.0"
 dependencies = [
     "numpy",
     "intervaltree",
+    "pydot",
 ]
 
 [tool.setuptools.package-dir]
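
For reference, below is a minimal sketch of the kind of 4-rank driver described in
the test plan above. It exercises torch.distributed.all_to_all() and
torch.distributed.all_to_all_single() with uneven splits, including the corner case
handled in _prep_all_to_allv() where one rank sends zero elements but still receives
data from its peers. The NCCL backend, four local GPUs, the rendezvous address and
port, the output file names, and the Kineto trace dump are assumptions for
illustration only; this is not the exact test case used for this change.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.profiler import ProfilerActivity, profile

WORLD_SIZE = 4


def run(rank: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=WORLD_SIZE)

    # Uneven splits: rank r sends r elements to every peer, so rank 0 sends
    # nothing at all but still receives 0 + 1 + 2 + 3 elements.
    in_splits = [rank] * WORLD_SIZE
    out_splits = list(range(WORLD_SIZE))

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        # all_to_all_single: flattened 1D tensors plus explicit split sizes.
        inp = torch.full((sum(in_splits),), float(rank), device="cuda")
        out = torch.empty(sum(out_splits), device="cuda")
        dist.all_to_all_single(
            out, inp, output_split_sizes=out_splits, input_split_sizes=in_splits
        )

        # all_to_all: one tensor per peer, same split pattern as above.
        in_list = [torch.full((s,), float(rank), device="cuda") for s in in_splits]
        out_list = [torch.empty(s, device="cuda") for s in out_splits]
        dist.all_to_all(out_list, in_list)

        torch.cuda.synchronize()

    # One Kineto trace per rank; the file name is arbitrary for this sketch.
    prof.export_chrome_trace(f"kineto_rank{rank}.json")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, nprocs=WORLD_SIZE)

The per-rank JSON traces produced by such a driver are the kind of input consumed by
analyze_profiler_trace() above; the uneven split sizes are what exercise the new
_is_uneven_all_to_all_evt() / _get_uneven_all_to_all_data_size() code paths.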