15 changes: 15 additions & 0 deletions train/comms/pt/comms_utils.py
@@ -1769,6 +1769,12 @@ def readArgs(self, parser: ArgumentParser) -> None:
            default=False,
            help="Toggle to initialize process group immediately during init_process_group call by passing device_id, see https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group",
        )
        parser.add_argument(
            "--enable-torch-nccl-timing",
            action="store_true",
            default=False,
            help="Enable recording of start events for all ProcessGroupNCCL collectives so that per-collective timing is accurate; may have a significant performance impact",
        )
        pass
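
For reference, a minimal standalone sketch (not part of this PR) of how the new `store_true` flag resolves: argparse maps the dashed option name to an underscored attribute, which is why `checkArgs` below reads `args.enable_torch_nccl_timing`.

```python
# Standalone sketch of the new flag's behavior; not part of the PR itself.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--enable-torch-nccl-timing", action="store_true", default=False)

# Absent flag -> False; present flag -> True; dashes become underscores.
assert parser.parse_args([]).enable_torch_nccl_timing is False
assert parser.parse_args(["--enable-torch-nccl-timing"]).enable_torch_nccl_timing is True
```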

@abstractmethod
@@ -1831,6 +1837,15 @@ def checkArgs(self, args: Namespace) -> None:
        else:
            os.environ["MASTER_PORT"] = args.master_port

        # Enabling "TORCH_NCCL_ENABLE_TIMING" can cause a performance regression in benchmark results.
        # The setting makes ProcessGroupNCCL record start events for all collectives, which allows each
        # collective operation to be timed accurately. However, it should be used with caution when
        # performance is a critical factor in the benchmark results, since it adds one extra function
        # call at each CUDA kernel start.
        if args.enable_torch_nccl_timing:
            os.environ["TORCH_NCCL_ENABLE_TIMING"] = "1"
        else:
            os.environ["TORCH_NCCL_ENABLE_TIMING"] = "0"


class paramCommsBench(ParamCommsBenchMixin, ParamCommsBenchBase):
    def __init__(self, supportedNwstacks: list[str] = None) -> None: