From 6857b287798842c69ab082295e0dd15fb48a7182 Mon Sep 17 00:00:00 2001
From: Kshitij Lakhani
Date: Wed, 24 Dec 2025 19:19:31 +0000
Subject: [PATCH 1/2] Pick a leaner set of combinations for TE JAX CP attn
 tests such that only those cp,dp,tp combinations are picked where cp*dp*tp
 is equal to num GPUs

Signed-off-by: Kshitij Lakhani
---
 tests/jax/distributed_test_base.py | 7 +++++--
 tests/jax/utils.py                 | 6 ++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py
index 137fa480ddb..f86f81ec48f 100644
--- a/tests/jax/distributed_test_base.py
+++ b/tests/jax/distributed_test_base.py
@@ -12,7 +12,7 @@
 
 from transformer_engine.jax.sharding import MeshResource
 
-from utils import assert_allclose, is_devices_enough
+from utils import assert_allclose, is_devices_enough, is_devices_equal
 
 
 def generate_configs():
@@ -49,7 +49,10 @@ def generate_context_parallel_configs_for_attn():
     TP_sizes = (1, 2)
     for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
         ndev = cp * tp * dp
-        if is_devices_enough(ndev):
+        # Run only those dp,cp,tp combinations which require exactly ndev GPUs. 
+        # For example, if ndev=8 and the number of GPUs is 8, then those combinations will be picked.
+        # However, if ndev=4 but the number of GPUs is 8, those combinations will not be picked. To pick such a combination, one can set CUDA_VISIBLE_DEVICES=0,1,2,3.
+        if is_devices_equal(ndev):
             # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations)
             if cp != 1:
                 configsL1.append(
diff --git a/tests/jax/utils.py b/tests/jax/utils.py
index 7194e387c73..c3311395a0d 100644
--- a/tests/jax/utils.py
+++ b/tests/jax/utils.py
@@ -46,6 +46,12 @@ def is_devices_enough(required):
     """
     return len(jax.devices()) >= required
 
+def is_devices_equal(required):
+    """
+    Check if the number of available GPUs is exactly equal to `required`.
+    """
+    return len(jax.devices()) == required
+
 
 def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
     # Generate broadcast dims for drop_path.
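For context, the selection logic that patch 1 introduces boils down to the standalone Python sketch below. The DP/CP size tuples and the helper name pick_exact_configs are illustrative assumptions (only TP_sizes = (1, 2) appears in the diff), and the real generate_context_parallel_configs_for_attn additionally splits the results into L1 and L2 config lists:

    from itertools import product

    import jax


    def is_devices_equal(required):
        # True only when JAX currently sees exactly `required` devices.
        return len(jax.devices()) == required


    def pick_exact_configs(dp_sizes=(1, 2), cp_sizes=(2, 4), tp_sizes=(1, 2)):
        # Keep only the (dp, cp, tp) triples whose product matches the visible
        # device count exactly; cp == 1 is skipped because the non-CP attention
        # tests already cover that case.
        return [
            (dp, cp, tp)
            for dp, cp, tp in product(dp_sizes, cp_sizes, tp_sizes)
            if cp != 1 and is_devices_equal(dp * cp * tp)
        ]

With these illustrative sizes on an 8-GPU machine, this selects, e.g., (2, 2, 2) and (1, 4, 2), while (1, 2, 2) is dropped because it needs only 4 devices.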
From d2f9634c30911709867cc51ffe5d4626ed7dc9d7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 24 Dec 2025 19:28:05 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/jax/distributed_test_base.py | 2 +-
 tests/jax/utils.py                 | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py
index f86f81ec48f..1593f50f042 100644
--- a/tests/jax/distributed_test_base.py
+++ b/tests/jax/distributed_test_base.py
@@ -49,7 +49,7 @@ def generate_context_parallel_configs_for_attn():
     TP_sizes = (1, 2)
     for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
         ndev = cp * tp * dp
-        # Run only those dp,cp,tp combinations which require exactly ndev GPUs. 
+        # Run only those dp,cp,tp combinations which require exactly ndev GPUs.
         # For example, if ndev=8 and the number of GPUs is 8, then those combinations will be picked.
         # However, if ndev=4 but the number of GPUs is 8, those combinations will not be picked. To pick such a combination, one can set CUDA_VISIBLE_DEVICES=0,1,2,3.
         if is_devices_equal(ndev):
diff --git a/tests/jax/utils.py b/tests/jax/utils.py
index c3311395a0d..39307075024 100644
--- a/tests/jax/utils.py
+++ b/tests/jax/utils.py
@@ -46,6 +46,7 @@ def is_devices_enough(required):
     """
     return len(jax.devices()) >= required
 
+
 def is_devices_equal(required):
     """
     Check if the number of available GPUs is exactly equal to `required`.
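As the comment in the diff notes, combinations that need fewer than the full device count can still be exercised by shrinking the set of visible GPUs. A minimal sketch of that workflow in Python (in practice the variable is usually set on the shell when launching pytest, as in the comment; it must be set before JAX initializes its CUDA backend):

    import os

    # Hide all but four GPUs so that ndev=4 combinations pass the
    # is_devices_equal check; set this before the first use of jax.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

    import jax

    assert len(jax.devices()) == 4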