Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.venv/
.vscode/
__pycache__/
1 change: 0 additions & 1 deletion et_replay/comm/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ class BaseBackend(ABC):
def __init__(self) -> None:
self.tcp_store = None
self.collectiveFunc = {
"all_to_all_single": self.all_to_all_single, # pyre-ignore[16]:
"all_to_all": self.all_to_all,
"all_to_allv": self.all_to_allv,
"all_reduce": self.all_reduce,
Expand Down
20 changes: 1 addition & 19 deletions et_replay/comm/backend/pytorch_dist_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def all_to_all(
return work

def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
# cpp layer all_to_allv is corresponding to python layer all_to_all_single
# pair=True mode does not support quantization
if (
collectiveArgs.all2all_qcomm
Expand Down Expand Up @@ -301,25 +302,6 @@ def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
if retFlag:
return work

def all_to_all_single(self, collectiveArgs, retFlag=False, pair=False):
# does not support quantization
if collectiveArgs.all2all_qcomm:
logger.warn("all_to_all_single does not support quantization")
return

work = dist.all_to_all_single(
collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
group=collectiveArgs.group,
async_op=collectiveArgs.asyncOp,
)

if collectiveArgs.asyncOp:
collectiveArgs.waitObj.append(work)

if retFlag:
return work

def all_gather(self, collectiveArgs, retFlag=False, pair=False):
if self.use_ext_dist:
retObj = collectiveArgs.group.all_gather(
Expand Down
60 changes: 15 additions & 45 deletions et_replay/comm/comms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def fixBeginSize(commsParams: commsParamsHolder, world_size: int) -> None:
if commsParams.collective in (
"all_to_all",
"all_to_allv",
"all_to_all_single",
"all_gather",
"all_gather_base",
"gather",
Expand Down Expand Up @@ -300,14 +299,13 @@ def checkQuantArgs(
if collective not in (
"all_to_all",
"all_to_allv",
"all_to_all_single",
"reduce",
"all_reduce",
):
raise NotImplementedError(
f"quantized communication for {collective} is currently unsupported."
)
if collective in ("all_to_all", "all_to_allv", "all_to_all_single"):
if collective in ("all_to_all", "all_to_allv"):
if (beginSize // 4) % quant_a2a_embedding_dim != 0:
logger.warning(
f"begin size {beginSize} must be a multiple of --quant-a2a-embedding-dim {quant_a2a_embedding_dim} for all_to_all operation"
Expand Down Expand Up @@ -349,7 +347,6 @@ def paramToCommName(name: str, supported_comms: list[str] | None = None) -> str:
"alltoall": "all_to_all",
"alltoallv": "all_to_allv",
"alltoallbase": "all_to_allv",
"alltoallsingle": "all_to_all_single",
"allreduce": "all_reduce",
"allgather": "all_gather",
"allgatherbase": "all_gather_base",
Expand Down Expand Up @@ -880,58 +877,28 @@ def _prep_all_to_allv(
opTensor = torch.Tensor()
if allocate:
# all_to_allv requires two tensors
# ipTensor has been allocated outside of this function, just pass in
opTensor = self.backendFuncs.alloc_random(
[numElementsOut], curDevice, dtype, scaleFactor
)
# recorded splits in trace is only for dim 0, but tensor in replay has been flattened.
# need to recalculate the splits for flattened 1D tensor
# corner case: one rank sends zeor data out, but receives data from other ranks, and vice versa.
self.collectiveArgs.opTensor_split = (
[numElementsOut // sum(curComm.outSplit) * i for i in curComm.outSplit]
[
numElementsOut // max(sum(curComm.outSplit), 1) * i
for i in curComm.outSplit
]
if curComm.outSplit
else None
)
self.collectiveArgs.ipTensor_split = (
[numElementsIn // sum(curComm.inSplit) * i for i in curComm.inSplit]
[numElementsIn // max(sum(curComm.inSplit), 1) * i for i in curComm.inSplit]
if curComm.inSplit
else None
)
return (ipTensor, opTensor)

def _prep_all_to_all_single(
self,
ipTensor: torch.Tensor,
curComm: commsArgs,
commsParams: commsParamsHolderBase,
numElementsIn: int,
numElementsOut: int,
world_size: int,
curDevice: str,
dtype: torch.dtype,
scaleFactor: float,
allocate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
ipTensor = torch.Tensor()
opTensor = torch.Tensor()
if allocate:
if commsParams.dcheck == 1:
ipTensor = self.backendFuncs.alloc_ones(
[numElementsIn],
curDevice,
commsParams.dtype,
self.initVal,
)
else:
ipTensor = self.backendFuncs.alloc_random(
[numElementsIn],
curDevice,
commsParams.dtype,
scaleFactor,
)
opTensor = self.backendFuncs.alloc_random(
[numElementsOut], curDevice, dtype, scaleFactor
)
return (ipTensor, opTensor)

def _prep_all_to_all(
self,
ipTensor: list[torch.Tensor],
Expand All @@ -948,17 +915,21 @@ def _prep_all_to_all(
ipTensor = []
opTensor = []
if allocate:
alloc_func = (
i_alloc_func = (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sanshang-nv, what is the reason for the change in this code block

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which point?

  1. i_ should be prefix of input_.
  2. i_scale_factor is the same usage as other prepare function, which is not used before this PR.
  3. the input and output of all_to_all is a list of tensor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I ask is why do you change to logic to create data for input/output tensors?

The original code is to create tensor with initVal when check_data is true, otherwise fill the tensor with all one.
With your change, you create data differently for input and output tensors. In some cases, you use scaleFactor. What is the reason behind this change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous logic:
if dcheck is true, use alloc_ones with initVal to create both input and output tensor. Problem is, output tensor should always use random for check.
if dcheck is false, use alloc_random, but still use initVal. It's wrong, should use scale_factor. Otherwise, why pass parameter scaleFactor in.

Fixed logic in PR:
if dcheck is true, use alloc_ones with initVal.
if dcheck is false, use alloc_random with scaleFactor.
output tensor shoudl always be allocated with alloc_random.

take

def _prep_all_gather(
as a reference.

@shengfukevin

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I think what value to initialize output tensor should not matter. right? Since it will be overwritten.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need a stamp from Meta side to get it approved.

self.backendFuncs.alloc_ones
if commsParams.dcheck == 1
else self.backendFuncs.alloc_random
)
i_scale_factor = self.initVal if commsParams.dcheck == 1 else scaleFactor
ipTensor = [
alloc_func(i, curDevice, commsParams.dtype, self.initVal)
i_alloc_func([i], curDevice, commsParams.dtype, i_scale_factor)
for i in curComm.inSplit
]

opTensor = [
alloc_func(i, curDevice, commsParams.dtype, self.initVal)
self.backendFuncs.alloc_random(
[i], curDevice, commsParams.dtype, scaleFactor
)
for i in curComm.outSplit
]
return (ipTensor, opTensor)
Expand Down Expand Up @@ -1247,7 +1218,6 @@ def prepComm(
# TODO: consider using this dictionary to check valid keywords rather than silently defaulting

dispatchDict = {
"all_to_all_single": self._prep_all_to_all_single,
"all_to_allv": self._prep_all_to_allv,
"all_to_all": self._prep_all_to_all,
"all_gather": self._prep_all_gather,
Expand Down
109 changes: 97 additions & 12 deletions et_replay/comm/profiler_trace_analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import ast
import functools
import json
import logging
import os
import pathlib
import re
import time
from collections import defaultdict
from typing import Any, Callable, Dict

Expand All @@ -12,6 +15,21 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def timer_decorator(func):
"""Decorator that prints the execution time of a function"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
return result

return wrapper


# refer to:
# https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/c10/core/ScalarType.h#L61
_dtype_size_map: Dict[str, int] = {
Expand Down Expand Up @@ -139,7 +157,54 @@ def _get_event_busbw_factor(evt):
return correction_factor_func(group_size)


def calculate_bw_(trace_data):
def _is_uneven_all_to_all_evt(evt):
coll_name = _get_dict_value(
evt["args"], "Collective name", f'Missing "Collective name" in event: {evt}'
)
return coll_name in ["all_to_all", "all_to_allv"] and (
ast.literal_eval(evt["args"]["In split size"])
or ast.literal_eval(evt["args"]["Out split size"])
)


def _get_uneven_all_to_all_data_size(evt, global_rank):
group_size = evt["args"]["Group size"]
local_rank = _parse_ranks(evt["args"]["Process Group Ranks"], group_size).index(
global_rank
)
in_elems_count = evt["args"]["In msg nelems"]
out_elems_count = evt["args"]["Out msg nelems"]
in_split_size = ast.literal_eval(evt["args"]["In split size"])
out_split_size = ast.literal_eval(evt["args"]["Out split size"])
dtype_size = _dtype_size_map[evt["args"]["dtype"]]

if (in_split_size and in_split_size[-1] == Ellipsis) or (
out_split_size and out_split_size[-1] == Ellipsis
):
in_split_size = []
out_split_size = []
logger.warning(f"Fallback to even all2all bw calculation for event: {evt}")

if in_split_size:
send_elems = in_elems_count - in_split_size[local_rank]
else:
send_elems = in_elems_count / group_size * (group_size - 1)

if out_split_size:
recv_elems = out_elems_count - out_split_size[local_rank]
else:
recv_elems = out_elems_count / group_size * (group_size - 1)

return max(send_elems, recv_elems) * dtype_size


def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
return round(
_get_uneven_all_to_all_data_size(evt, global_rank) / evt["dur"] * 1e-3, 2
)


def calculate_bw_(trace_data, global_rank):
nccl_events = [
i
for i in trace_data["traceEvents"]
Expand All @@ -163,7 +228,11 @@ def calculate_bw_(trace_data):

algbw = _calculate_algbw(evt)
busbw_factor = _get_event_busbw_factor(evt)
busbw = round(algbw * busbw_factor, 2)
if _is_uneven_all_to_all_evt(evt):
# calculate busbw for uneven all_to_all
busbw = _calculate_busbw_for_uneven_all_to_all(evt, global_rank)
else:
busbw = round(algbw * busbw_factor, 2)

evt["args"]["algbw (GB/sec)"] = algbw
evt["args"]["busbw (GB/sec)"] = busbw
Expand All @@ -178,7 +247,7 @@ def calculate_bw_(trace_data):
logger.error(f"- Error: {err_msg}")


def calculate_sbw(trace_data):
def calculate_sbw(trace_data, global_rank):
# calculate shared bw per rank
nccl_events = [
i
Expand All @@ -193,6 +262,8 @@ def calculate_sbw(trace_data):

total_data_size = sum(
_calculate_event_data_size(evt) * _get_event_busbw_factor(evt)
if not _is_uneven_all_to_all_evt(evt)
else _get_uneven_all_to_all_data_size(evt, global_rank)
for evt in nccl_events
)

Expand Down Expand Up @@ -232,25 +303,36 @@ def pick_iter_e2e_time_(trace_data, tl):

def pick_comm_bw_(trace_data, comm_bw_data):
rank = trace_data["distributedInfo"]["rank"]

group_ranks_to_pg_id = defaultdict(list)
for pg in trace_data["distributedInfo"]["pg_config"]:
group_ranks_to_pg_id[tuple(pg["ranks"])].append(int(pg["pg_name"]))
for ranks in group_ranks_to_pg_id:
group_ranks_to_pg_id[ranks].sort()

nccl_events = [
i
for i in trace_data["traceEvents"]
if i.get("cat", "") == "kernel"
and i["name"].startswith(("ncclDevKernel_", "ncclKernel_"))
and "algbw (GB/sec)" in i["args"]
]
pg_name2config = {
pg["pg_name"]: pg for pg in trace_data["distributedInfo"]["pg_config"]
}
for evt in nccl_events:
knl_name = evt["name"][: evt["name"].index("(")]
coll_name = evt["args"]["Collective name"]
data_size = _calculate_event_data_size(evt)
ranks_count = evt["args"]["Group size"]

ranks = _parse_ranks(evt["args"]["Process Group Ranks"], ranks_count)
ranks_count = evt["args"]["Group size"]
pg_id = int(evt["args"]["Process Group Name"])
pg = (*ranks, pg_id) if ranks and rank == min(ranks) else None
ranks = pg_name2config[evt["args"]["Process Group Name"]]["ranks"]

# If there are multiple process groups with the same ranks, the last element
# of this tuple is the idential index to differentiate them across ranks.
pg = (*ranks, group_ranks_to_pg_id[tuple(ranks)].index(pg_id))

# TODO: calculation of unbalanced all2all bw needs to be improved
# all2all is implemented by single ncclDevKernel_SendRecv() in NCCL
comm_bw_data[(knl_name, coll_name, data_size, ranks_count)].append(
[
evt["dur"],
Expand All @@ -261,6 +343,7 @@ def pick_comm_bw_(trace_data, comm_bw_data):
)


@timer_decorator
def analyze_profiler_trace(trace_dir: str, report_dir: str):
"""
Analyse input PyTorch profiler trace (i.e. Kineto trace) and generate report.
Expand All @@ -282,7 +365,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
# list of shared bw
sbw_lst = []

# key is (kernel_name, data size, ranks number)
# key is (kernel_name, coll name, data size, ranks count)
# value is list of [dur, algbw, busbw, pg]
comm_bw_data = defaultdict(list)

Expand All @@ -293,13 +376,15 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
with open(fpath.path, "r", encoding="utf-8") as f:
trace = json.load(f)

calculate_bw_(trace)
global_rank = trace["distributedInfo"]["rank"]
calculate_bw_(trace, global_rank)

with open(
os.path.join(processed_trace_dir, fpath.name), "w", encoding="utf-8"
) as f:
json.dump(trace, f)

sbw_lst.append(calculate_sbw(trace))
sbw_lst.append(calculate_sbw(trace, global_rank))

pick_iter_e2e_time_(trace, iter_e2e_time)
pick_comm_bw_(trace, comm_bw_data)
Expand Down Expand Up @@ -330,7 +415,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
f"avg. E2ETime of iters among all ranks: {sum(iter_e2e_time) / len(iter_e2e_time) / 1e3 :.3f} ms\n"
)
f.write(
f"avg. SharedBW (i.e. sum(data_size * busbw_factor) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
f"avg. SharedBW (i.e. sum(busbw_data_size) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
)

f.write(
Expand Down
1 change: 1 addition & 0 deletions et_replay/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ version = "0.5.0"
dependencies = [
"numpy",
"intervaltree",
"pydot",
]

[tool.setuptools.package-dir]
Expand Down