From 8dc180bff85ead8f97acd43a36feed7bb8f870cc Mon Sep 17 00:00:00 2001
From: Tan Hoang
Date: Wed, 2 Apr 2025 03:43:27 -0700
Subject: [PATCH] add explicit setting to disable torch nccl timing

Summary:
Add a param in the base parser to set the `torch_nccl_enable_timing` variable to `False` by default, and only set it to `True` if the user needs it. This value is used only by the flight recorder (for debugging purposes) and, if enabled, significantly affects benchmarking performance in blocking mode (~30% for small-to-mid message sizes).

Reviewed By: kingchc

Differential Revision: D72240605
---
 train/comms/pt/comms_utils.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/train/comms/pt/comms_utils.py b/train/comms/pt/comms_utils.py
index 614e9ce6..76d3fdaf 100644
--- a/train/comms/pt/comms_utils.py
+++ b/train/comms/pt/comms_utils.py
@@ -1769,6 +1769,12 @@ def readArgs(self, parser: ArgumentParser) -> None:
             default=False,
             help="Toggle to initialize progress group immediately during init_process_group call by passing device_id, see https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group",
         )
+        parser.add_argument(
+            "--enable-torch-nccl-timing",
+            action="store_true",
+            default=False,
+            help="Enable recording start events for all ProcessGroupNCCL collectives and compute accurate per-collective timing; may have a significant performance impact",
+        )
         pass

     @abstractmethod
@@ -1831,6 +1837,15 @@ def checkArgs(self, args: Namespace) -> None:
         else:
             os.environ["MASTER_PORT"] = args.master_port

+        # Enabling "TORCH_NCCL_ENABLE_TIMING" can cause a performance regression in benchmark results.
+        # The setting records start events for all ProcessGroupNCCL collectives, which allows accurate
+        # timing of each collective operation. However, it should be used with caution when performance
+        # is a critical factor in the benchmark, since it adds one extra function call per CUDA kernel start.
+        if args.enable_torch_nccl_timing:
+            os.environ["TORCH_NCCL_ENABLE_TIMING"] = "1"
+        else:
+            os.environ["TORCH_NCCL_ENABLE_TIMING"] = "0"
+

 class paramCommsBench(ParamCommsBenchMixin, ParamCommsBenchBase):
     def __init__(self, supportedNwstacks: list[str] = None) -> None:
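
For reference, a minimal standalone sketch (not part of the patch) of the flag-to-env-var wiring the diff adds to readArgs/checkArgs. The flag and env-var names come from the diff above; the parse_args input is purely illustrative:

# Standalone sketch of the behavior added by this patch; the real logic lives
# in readArgs/checkArgs in train/comms/pt/comms_utils.py.
import os
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--enable-torch-nccl-timing",
    action="store_true",
    default=False,
    help="Record start events for all ProcessGroupNCCL collectives",
)
args = parser.parse_args(["--enable-torch-nccl-timing"])  # simulate user opt-in

# Mirror of the checkArgs logic: set the env var before any NCCL process
# group is created, since PyTorch reads it at process-group creation time.
os.environ["TORCH_NCCL_ENABLE_TIMING"] = "1" if args.enable_torch_nccl_timing else "0"
print(os.environ["TORCH_NCCL_ENABLE_TIMING"])  # prints "1"

In an actual benchmark run, the user would pass the new flag alongside the existing benchmark arguments, and leave it off (the default) whenever timing overhead would skew blocking-mode results.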