diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py
index 137fa480ddb..1593f50f042 100644
--- a/tests/jax/distributed_test_base.py
+++ b/tests/jax/distributed_test_base.py
@@ -12,7 +12,7 @@
 
 from transformer_engine.jax.sharding import MeshResource
 
-from utils import assert_allclose, is_devices_enough
+from utils import assert_allclose, is_devices_enough, is_devices_equal
 
 
 def generate_configs():
@@ -49,7 +49,10 @@ def generate_context_parallel_configs_for_attn():
     TP_sizes = (1, 2)
     for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
         ndev = cp * tp * dp
-        if is_devices_enough(ndev):
+        # Run only those (dp, cp, tp) combinations that require exactly ndev GPUs.
+        # For example, if ndev=8 and 8 GPUs are available, those combinations are picked.
+        # However, if ndev=4 but 8 GPUs are available, those combinations are skipped. To pick such a combination, set CUDA_VISIBLE_DEVICES=0,1,2,3.
+        if is_devices_equal(ndev):
             # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations)
             if cp != 1:
                 configsL1.append(
diff --git a/tests/jax/utils.py b/tests/jax/utils.py
index 7194e387c73..39307075024 100644
--- a/tests/jax/utils.py
+++ b/tests/jax/utils.py
@@ -47,6 +47,13 @@ def is_devices_enough(required):
     return len(jax.devices()) >= required
 
 
+def is_devices_equal(required):
+    """
+    Check if the number of available GPUs is exactly equal to the required count.
+    """
+    return len(jax.devices()) == required
+
+
 def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
     # Generate broadcast dims for drop_path.
     drop_path_shape = list(range(0, len(shape)))
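
For illustration only (not part of the patch), a minimal sketch of how the old and new helpers differ, assuming jax.devices() reports the GPUs visible to the process:

    import jax

    def is_devices_enough(required):
        # Old check: a config runs whenever at least `required` devices are visible.
        return len(jax.devices()) >= required

    def is_devices_equal(required):
        # New check: a config runs only when the visible device count matches exactly.
        return len(jax.devices()) == required

    # With 8 visible GPUs:
    #   is_devices_enough(4)  -> True   (old behavior: the 4-GPU configs would run)
    #   is_devices_equal(4)   -> False  (new behavior: skipped unless CUDA_VISIBLE_DEVICES=0,1,2,3)
    #   is_devices_equal(8)   -> True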