diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 5372018ae87..a2186a34e22 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -2,6 +2,8 @@
 #
 # See LICENSE for license information.
 
+from contextlib import contextmanager
+import math
 import os
 import pytest
 import jax
@@ -26,6 +28,7 @@
     CPStrategy,
     ReorderStrategy,
 )
+from transformer_engine.jax.sharding import global_shard_guard, MeshResource
 
 DTYPES = [jnp.bfloat16]
 
@@ -667,3 +670,154 @@ def test(self, cp_size, shape, qkv_format, reorder_strategy):
         inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
 
         assert jnp.array_equal(inversed, ref)
+
+
+class TestMismatchingQKVSharding:
+
+    @staticmethod
+    @contextmanager
+    def _mesh_ctx():
+        mesh_resource = MeshResource(
+            dp_resource=None,
+            tp_resource=None,
+            fsdp_resource="fsdp",
+            pp_resource=None,
+            cp_resource="context",
+        )
+        mesh = jax.make_mesh(
+            axis_shapes=(4, 1),
+            axis_names=(
+                "fsdp",
+                "context",
+            ),
+            axis_types=(
+                jax.sharding.AxisType.Auto,
+                jax.sharding.AxisType.Auto,
+            ),
+        )
+        with global_shard_guard(mesh_resource), mesh:
+            yield mesh, mesh_resource
+
+    def _generate_inputs(
+        self,
+        batch_size,
+        seq_len,
+        num_heads,
+        head_dim,
+        mesh,
+        query_spec,
+        key_spec,
+        value_spec,
+        q_segment_ids_spec,
+        kv_segment_ids_spec,
+    ):
+        query = jax.random.uniform(
+            jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, head_dim), dtype=jnp.bfloat16
+        )
+        key = jax.random.uniform(
+            jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, head_dim), dtype=jnp.bfloat16
+        )
+        value = jax.random.uniform(
+            jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, head_dim), dtype=jnp.bfloat16
+        )
+        q_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
+        kv_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
+
+        def make_sharding(spec):
+            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))
+
+        query = jax.device_put(query, make_sharding(query_spec))
+        key = jax.device_put(key, make_sharding(key_spec))
+        value = jax.device_put(value, make_sharding(value_spec))
+        q_segment_ids = jax.device_put(q_segment_ids, make_sharding(q_segment_ids_spec))
+        kv_segment_ids = jax.device_put(kv_segment_ids, make_sharding(kv_segment_ids_spec))
+        return query, key, value, q_segment_ids, kv_segment_ids
+
+    def _attn_impl(self, query, key, value, q_segment_ids, kv_segment_ids):
+        from transformer_engine.jax.attention import SequenceDescriptor
+        from transformer_engine.jax.flax import DotProductAttention
+
+        head_dim = query.shape[-1]
+        qkv_layout = "BSHD_BSHD_BSHD"
+        segment_ids = SequenceDescriptor.from_segment_ids_and_pos(
+            segment_ids=(q_segment_ids, kv_segment_ids),
+        )
+        dpa_layer = DotProductAttention(
+            head_dim=head_dim,
+            num_attention_heads=query.shape[1],
+            num_gqa_groups=key.shape[1],
+            attn_mask_type="no_mask",
+            attn_bias_type="no_bias",
+            attention_dropout=0,
+            dropout_rng_name="aqt",
+            dtype=jax.numpy.bfloat16,
+            float32_logits=True,
+            qkv_layout=qkv_layout,
+            scale_factor=1 / math.sqrt(query.shape[-1]),
+            transpose_batch_sequence=False,
+            window_size=None,
+            context_parallel_causal_load_balanced=False,
+            context_parallel_axis="context",
+            context_parallel_strategy="all_gather",  # "all_gather" or "ring"
+            max_segments_per_seq=1,
+        )
+
+        return dpa_layer.apply(
+            {},
+            query.astype(jax.numpy.bfloat16),
+            key.astype(jax.numpy.bfloat16),
+            value.astype(jax.numpy.bfloat16),
+            segment_ids,
+        )
+
+    def _impl(
+        self,
+        mesh,
+        query_spec,
+        key_spec,
+        value_spec,
+        q_segment_ids_spec,
+        kv_segment_ids_spec,
+        expected_error_message,
+    ):
+        did_error = False
+        try:
+            query, key, value, q_segment_ids, kv_segment_ids = self._generate_inputs(
+                batch_size=4,
+                seq_len=8192,
+                num_heads=30,
+                head_dim=128,
+                mesh=mesh,
+                query_spec=query_spec,
+                key_spec=key_spec,
+                value_spec=value_spec,
+                q_segment_ids_spec=q_segment_ids_spec,
+                kv_segment_ids_spec=kv_segment_ids_spec,
+            )
+
+            out = jax.jit(self._attn_impl)(query, key, value, q_segment_ids, kv_segment_ids)
+        except jax.errors.JaxRuntimeError as e:
+            did_error = True
+            assert expected_error_message in str(
+                e
+            ), f"Expected error message '{expected_error_message}' not found in '{str(e)}'"
+        assert (
+            did_error
+        ), "Expected an error due to mismatching QKV sharding specs, but no error was raised."
+
+    def test_mismatching_qkv_sharding_separate_qkv(self):
+        with self._mesh_ctx() as (mesh, mesh_resource):
+            self._impl(
+                mesh=mesh,
+                query_spec=(mesh_resource.fsdp_resource, None, None, None),
+                key_spec=(mesh_resource.fsdp_resource, None, None, None),
+                # Value is replicated and mismatching
+                value_spec=(None, None, None, None),
+                q_segment_ids_spec=(mesh_resource.fsdp_resource, None),
+                kv_segment_ids_spec=(mesh_resource.fsdp_resource, None),
+                expected_error_message=(
+                    "Q, K, and V sharding specs must be identical but received"
+                    " q_spec=PartitionSpec('fsdp', None, None, None), k_spec=PartitionSpec('fsdp',"
+                    " None, None, None), v_spec=PartitionSpec(None, None, None, None)"
+                ),
+            )
diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index f0778bfd292..c7546b869ec 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -687,6 +687,16 @@ def partition(config, mesh, arg_infos, result_infos):
             arg_shardings[-2] = arg_shardings[-4]
         arg_shardings = tuple(arg_shardings)
         out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
+
+        if config.qkv_layout.is_separate():
+            q_spec = get_padded_spec(arg_infos[0])
+            k_spec = get_padded_spec(arg_infos[1])
+            v_spec = get_padded_spec(arg_infos[2])
+            assert q_spec == k_spec == v_spec, (
+                f"Q, K, and V sharding specs must be identical but received {q_spec=}, {k_spec=},"
+                f" {v_spec=}"
+            )
+
         impl = partial(FusedAttnFwdPrimitive.impl, config=config)
         return mesh, impl, out_shardings, arg_shardings
 