Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions et_replay/comm/backend/pytorch_dist_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,24 +227,12 @@ def all_to_all(
"Not using batched embedding tables because extend distributed package not in use"
)

if isinstance(collectiveArgs.opTensor, list):
work = dist.all_to_all(
collectiveArgs.opTensor,
collectiveArgs.ipTensor,
group=self.get_collective_group(collectiveArgs),
async_op=collectiveArgs.asyncOp,
)
else:
work = dist.all_to_all_single(
collectiveArgs.opTensor
if not pair
else collectiveArgs.opTensor_pair,
collectiveArgs.ipTensor
if not pair
else collectiveArgs.ipTensor_pair,
group=self.get_collective_group(collectiveArgs),
async_op=collectiveArgs.asyncOp,
)
work = dist.all_to_all(
collectiveArgs.opTensor,
collectiveArgs.ipTensor,
group=self.get_collective_group(collectiveArgs),
async_op=collectiveArgs.asyncOp,
)

if collectiveArgs.asyncOp:
collectiveArgs.waitObj.append(work)
Expand Down
23 changes: 10 additions & 13 deletions et_replay/comm/commsTraceParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import logging

import math

from et_replay import ExecutionTrace
from et_replay.comm import comms_utils
from et_replay.comm.backend.base_backend import supportedP2pOps
Expand Down Expand Up @@ -189,22 +191,17 @@ def _parse_comms_op_node( # noqa: C901
comm_args.root = comm_args.groupRanks[recorded_rank]
comm_args.groupRanks = comm_args.groupRanks

if comm_args.comms == "all_to_allv":
if comm_args.comms == "all_to_all":
# flatten each tensor and store the number of elements in the split field
comm_args.inSplit = [math.prod(i) for i in node.input_shapes[0]]
comm_args.outSplit = [math.prod(i) for i in node.output_shapes[0]]
elif comm_args.comms == "all_to_allv":
if not comm_args.worldSize:
# if no pg info provided, use total ranks as world size
comm_args.worldSize = total_ranks
comm_args.inSplit = (
json.loads(node.commArgs.in_split_size)
if json.loads(node.commArgs.in_split_size)
else [int(comm_args.inMsgSize / comm_args.worldSize)]
* comm_args.worldSize
)
comm_args.outSplit = (
json.loads(node.commArgs.out_split_size)
if json.loads(node.commArgs.out_split_size)
else [int(comm_args.outMsgSize / comm_args.worldSize)]
* comm_args.worldSize
)
comm_args.inSplit = json.loads(node.commArgs.in_split_size)
comm_args.outSplit = json.loads(node.commArgs.out_split_size)

comms_op_list.append(comm_args)

return comms_op_list
Expand Down
56 changes: 21 additions & 35 deletions et_replay/comm/comms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,16 +876,17 @@ def _prep_all_to_allv(
opTensor = self.backendFuncs.alloc_random(
[numElementsOut], curDevice, dtype, scaleFactor
)
# all_to_allv requires tensors that specify per-rank splits.
# The splits recorded in the trace apply only to dim 0, but the replayed tensor
# has been flattened, so the splits must be recalculated for the flattened 1D tensor.
self.collectiveArgs.opTensor_split = (
curComm.outSplit
if (curComm.outSplit is not None)
else [(numElementsOut // world_size) for _ in range(world_size)]
[numElementsOut // sum(curComm.outSplit) * i for i in curComm.outSplit]
if curComm.outSplit
else None
)
self.collectiveArgs.ipTensor_split = (
curComm.inSplit
if (curComm.inSplit is not None)
else [(numElementsIn // world_size) for _ in range(world_size)]
[numElementsIn // sum(curComm.inSplit) * i for i in curComm.inSplit]
if curComm.inSplit
else None
)
return (ipTensor, opTensor)

Expand Down Expand Up @@ -937,37 +938,22 @@ def _prep_all_to_all(
scaleFactor: float,
allocate: bool = True,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
# all_to_all requires two tensor lists, e.g., List[torch.Tensor]

ipTensor = []
opTensor = []
if allocate:
if commsParams.dcheck == 1:
for _ in range(world_size):
ipTensor.append(
self.backendFuncs.alloc_ones(
[(numElementsIn // world_size)],
curDevice,
commsParams.dtype,
self.initVal,
)
)
else:
for _ in range(world_size):
ipTensor.append(
self.backendFuncs.alloc_random(
[(numElementsIn // world_size)],
curDevice,
commsParams.dtype,
scaleFactor,
)
)
for _ in range(world_size):
opTensor.append(
self.backendFuncs.alloc_random(
[(numElementsOut // world_size)], curDevice, dtype, scaleFactor
)
)
alloc_func = (
self.backendFuncs.alloc_ones
if commsParams.dcheck == 1
else self.backendFuncs.alloc_random
)
ipTensor = [
alloc_func(i, curDevice, commsParams.dtype, self.initVal)
for i in curComm.inSplit
]
opTensor = [
alloc_func(i, curDevice, commsParams.dtype, self.initVal)
for i in curComm.outSplit
]
return (ipTensor, opTensor)

def _prep_all_gather(
Expand Down
15 changes: 10 additions & 5 deletions et_replay/comm/profiler_trace_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,17 @@ def pick_comm_bw_(trace_data, comm_bw_data):
]
for evt in nccl_events:
knl_name = evt["name"][: evt["name"].index("(")]
coll_name = evt["args"]["Collective name"]
data_size = _calculate_event_data_size(evt)
ranks_count = evt["args"]["Group size"]

ranks = _parse_ranks(evt["args"]["Process Group Ranks"], ranks_count)
pg_id = int(evt["args"]["Process Group Name"])
pg = (*ranks, pg_id) if ranks and rank == min(ranks) else None

comm_bw_data[(knl_name, data_size, ranks_count)].append(
# TODO: bandwidth calculation for unbalanced all-to-all needs to be improved;
# all-to-all is implemented by a single ncclDevKernel_SendRecv() kernel in NCCL
comm_bw_data[(knl_name, coll_name, data_size, ranks_count)].append(
[
evt["dur"],
evt["args"]["algbw (GB/sec)"],
Expand Down Expand Up @@ -331,25 +334,27 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
)

f.write(
f'\n{" ":>70s}|{" ":>5s}|{"AVG.":^19s}|{"p01":^8s}|{"p50":^8s}|{"p90":^8s}|{"p99":^8s}|\n'
f'\n{" ":>86s}|{" ":>5s}|{"AVG.":^19s}|{"p01":^8s}|{"p50":^8s}|{"p90":^8s}|{"p99":^8s}|\n'
)

f.write(
f'{"kernel":>50s} {"size":>12s} {"#rks":>6s}|{"#pgs":>5s}|{" dur":>10s} '
f'{"kernel":>50s} {"coll":>15s} {"size":>12s} {"#rks":>6s}|{"#pgs":>5s}|{" dur":>10s} '
)
for _ in range(5): # average, p01, p50, p90, p99
f.write(f'{" busbw":>8s}|')
f.write("\n")

f.write(
f'{" ":>50s} {" (B)":>12s} {" ":>6s}|{" ":>5s}|{" (ms)":>10s} '
f'{" ":>66s} {" (B)":>12s} {" ":>6s}|{" ":>5s}|{" (ms)":>10s} '
)
for _ in range(5): # average, p50, p90, p99
f.write(f'{"(GB/s)":>8s}|')
f.write("\n")

for k, v in comm_bw_summary.items():
f.write(f"{k[0]:>50s} {k[1]:>12d} {k[2]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} ")
f.write(
f"{k[0]:>50s} {k[1]:>15s} {k[2]:>12d} {k[3]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} "
)
for i in range(2, len(v)):
f.write(f"{v[i]:>8.2f}|")
f.write("\n")
2 changes: 1 addition & 1 deletion et_replay/tools/comm_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ def replaySingle(

if groupRank >= 0:
commDesc = f"{str(curComm.comms)}: NumElemsIn={curComm.inMsgSize}, NumElemsOut={curComm.outMsgSize}, Dtype={curComm.dtype}"
if curComm.comms == "all_to_allv":
if curComm.comms in ("all_to_all", "all_to_allv"):
commDesc += (
f", InSplit={curComm.inSplit}, OutSplit={curComm.outSplit}"
)
Expand Down