diff --git a/train/comms/pt/comms_utils.py b/train/comms/pt/comms_utils.py
index 614e9ce6..76d3fdaf 100644
--- a/train/comms/pt/comms_utils.py
+++ b/train/comms/pt/comms_utils.py
@@ -1769,6 +1769,12 @@ def readArgs(self, parser: ArgumentParser) -> None:
             default=False,
             help="Toggle to initialize progress group immediately during init_process_group call by passing device_id, see https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group",
         )
+        parser.add_argument(
+            "--enable-torch-nccl-timing",
+            action="store_true",
+            default=False,
+            help="Record start events for all ProcessGroupNCCL collectives to compute accurate per-collective timing; may have a significant performance impact",
+        )
         pass
 
     @abstractmethod
@@ -1831,6 +1837,15 @@ def checkArgs(self, args: Namespace) -> None:
         else:
             os.environ["MASTER_PORT"] = args.master_port
 
+        # Enabling "TORCH_NCCL_ENABLE_TIMING" records start events for all ProcessGroupNCCL
+        # collectives, which allows accurate timing of each collective operation. However,
+        # it can cause a performance regression in benchmark results, since it adds one
+        # extra CUDA event record at each kernel launch, so use it with caution.
+        if args.enable_torch_nccl_timing:
+            os.environ["TORCH_NCCL_ENABLE_TIMING"] = "1"
+        else:
+            os.environ["TORCH_NCCL_ENABLE_TIMING"] = "0"
+
 
 class paramCommsBench(ParamCommsBenchMixin, ParamCommsBenchBase):
     def __init__(self, supportedNwstacks: list[str] = None) -> None:
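
For context, below is a minimal standalone sketch of the pattern this patch relies on (not part of the change itself; the helper name init_with_timing and the torchrun launch are illustrative assumptions). TORCH_NCCL_ENABLE_TIMING is read by ProcessGroupNCCL when the process group is created, so the variable has to be exported before init_process_group, which is why checkArgs() sets it up front rather than leaving it to the user.

# Standalone sketch (not part of the patch above; init_with_timing is an
# illustrative helper name, not an API from this repo).
import os

import torch
import torch.distributed as dist


def init_with_timing(enable_timing: bool) -> None:
    # Mirror the gating in checkArgs(): "1" records a start event per collective
    # (accurate per-collective timing, extra overhead); "0" keeps the default.
    # The variable must be set before the NCCL process group is constructed.
    os.environ["TORCH_NCCL_ENABLE_TIMING"] = "1" if enable_timing else "0"
    dist.init_process_group(backend="nccl")


if __name__ == "__main__":
    # Assumes launch via torchrun so RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set.
    init_with_timing(enable_timing=True)
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    x = torch.ones(1 << 20, device="cuda")
    dist.all_reduce(x)
    torch.cuda.synchronize()
    dist.destroy_process_group()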