154 changes: 154 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
@@ -2,6 +2,8 @@
#
# See LICENSE for license information.

from contextlib import contextmanager
import math
import os
import pytest
import jax
@@ -26,6 +28,7 @@
CPStrategy,
ReorderStrategy,
)
from transformer_engine.jax.sharding import global_shard_guard, MeshResource


DTYPES = [jnp.bfloat16]
@@ -667,3 +670,154 @@ def test(self, cp_size, shape, qkv_format, reorder_strategy):
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)

assert jnp.array_equal(inversed, ref)


class TestMismatchingQKVSharding:

@staticmethod
@contextmanager
def _mesh_ctx():
mesh_resource = MeshResource(
dp_resource=None,
tp_resource=None,
fsdp_resource="fsdp",
pp_resource=None,
cp_resource="context",
)
mesh = jax.make_mesh(
axis_shapes=(4, 1),
axis_names=(
"fsdp",
"context",
),
axis_types=(
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Auto,
),
)
with global_shard_guard(mesh_resource), mesh:
yield mesh, mesh_resource

def _generate_inputs(
self,
batch_size,
seq_len,
num_heads,
head_dim,
mesh,
query_spec,
key_spec,
value_spec,
q_segment_ids_spec,
kv_segment_ids_spec,
):
query = jax.random.uniform(
jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, head_dim), dtype=jnp.bfloat16
)
key = jax.random.uniform(
jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, head_dim), dtype=jnp.bfloat16
)
value = jax.random.uniform(
jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, head_dim), dtype=jnp.bfloat16
)
q_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
kv_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)

def make_sharding(spec):
return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))

query = jax.device_put(query, make_sharding(query_spec))
key = jax.device_put(key, make_sharding(key_spec))
value = jax.device_put(value, make_sharding(value_spec))
q_segment_ids = jax.device_put(q_segment_ids, make_sharding(q_segment_ids_spec))
kv_segment_ids = jax.device_put(kv_segment_ids, make_sharding(kv_segment_ids_spec))
return query, key, value, q_segment_ids, kv_segment_ids

def _attn_impl(self, query, key, value, q_segment_ids, kv_segment_ids):
from transformer_engine.jax.attention import SequenceDescriptor
from transformer_engine.jax.flax import DotProductAttention

head_dim = query.shape[-1]
qkv_layout = "BSHD_BSHD_BSHD"
segment_ids = SequenceDescriptor.from_segment_ids_and_pos(
segment_ids=(q_segment_ids, kv_segment_ids),
)
dpa_layer = DotProductAttention(
head_dim=head_dim,
            num_attention_heads=query.shape[-2],
            num_gqa_groups=key.shape[-2],
attn_mask_type="no_mask",
attn_bias_type="no_bias",
attention_dropout=0,
dropout_rng_name="aqt",
dtype=jax.numpy.bfloat16,
float32_logits=True,
qkv_layout=qkv_layout,
scale_factor=1 / math.sqrt(query.shape[-1]),
transpose_batch_sequence=False,
window_size=None,
context_parallel_causal_load_balanced=False,
context_parallel_axis="context",
context_parallel_strategy="all_gather", # "all_gather" or "ring"
max_segments_per_seq=1,
)

return dpa_layer.apply(
{},
query.astype(jax.numpy.bfloat16),
key.astype(jax.numpy.bfloat16),
value.astype(jax.numpy.bfloat16),
segment_ids,
)

def _impl(
self,
mesh,
query_spec,
key_spec,
value_spec,
q_segment_ids_spec,
kv_segment_ids_spec,
expected_error_message,
):
did_error = False
try:
query, key, value, q_segment_ids, kv_segment_ids = self._generate_inputs(
batch_size=4,
seq_len=8192,
num_heads=30,
head_dim=128,
mesh=mesh,
query_spec=query_spec,
key_spec=key_spec,
value_spec=value_spec,
q_segment_ids_spec=q_segment_ids_spec,
kv_segment_ids_spec=kv_segment_ids_spec,
)

out = jax.jit(self._attn_impl)(query, key, value, q_segment_ids, kv_segment_ids)
except jax.errors.JaxRuntimeError as e:
did_error = True
assert expected_error_message in str(
e
), f"Expected error message '{expected_error_message}' not found in '{str(e)}'"
assert (
did_error
), "Expected an error due to mismatching QKV sharding specs, but no error was raised."

def test_mismatching_qkv_sharding_separate_qkv(self):
with self._mesh_ctx() as (mesh, mesh_resource):
self._impl(
mesh=mesh,
query_spec=(mesh_resource.fsdp_resource, None, None, None),
key_spec=(mesh_resource.fsdp_resource, None, None, None),
# Value is replicated and mismatching
value_spec=(None, None, None, None),
q_segment_ids_spec=(mesh_resource.fsdp_resource, None),
kv_segment_ids_spec=(mesh_resource.fsdp_resource, None),
expected_error_message=(
"Q, K, and V sharding specs must be identical but received"
" q_spec=PartitionSpec('fsdp', None, None, None), k_spec=PartitionSpec('fsdp',"
" None, None, None), v_spec=PartitionSpec(None, None, None, None)"
),
)
10 changes: 10 additions & 0 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -687,6 +687,16 @@ def partition(config, mesh, arg_infos, result_infos):
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

if config.qkv_layout.is_separate():
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
Collaborator Author:

@KshitijLakhani does this assertion make sense? For the case I'm testing above, where the batch dimension should always have the same sharding, I think it does. But I'm not yet familiar enough with other parallelism techniques like CP to know whether this assertion is valid on the non-batch axes.

Collaborator:

I can think of MQA/GQA needing different sharding for Q and K along the head dimension. For example, with 8 query heads and 1 K/V head, the query could use a PartitionSpec that splits its heads across 4 devices (2 query heads per device) while the key uses a PartitionSpec that replicates across devices (i.e. None), so I would not be as restrictive.

I think requiring identical sharding along the batch dimension is fine, and probably along the sequence dimension too, since I don't think any of the CP strategies need different PartitionSpecs for Q, K, and V there. It should be okay along the hidden dimension as well.

cc: @mgoldfarb-nvidia @huanghua1994 @mingxu1067 to chime in
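
A minimal sketch of the GQA/MQA layout described in the comment above (illustrative only, not part of this PR; the 4-device "tp" mesh axis, array sizes, and variable names are assumptions): the query's 8 heads are sharded over the axis while the single K/V head is replicated, so the Q and K/V PartitionSpecs legitimately differ on the head dimension.

```python
# Illustrative sketch of the MQA/GQA sharding case from the review comment above.
# Assumes a single 4-device "tp" mesh axis; requires at least 4 devices to run as written.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

mesh = jax.make_mesh((4,), ("tp",))

batch, seqlen, head_dim = 2, 128, 64
q = jnp.zeros((batch, seqlen, 8, head_dim), dtype=jnp.bfloat16)   # 8 query heads
kv = jnp.zeros((batch, seqlen, 1, head_dim), dtype=jnp.bfloat16)  # 1 shared K/V head

# Q heads split across the "tp" axis (2 heads per device); K/V replicated on it.
q = jax.device_put(q, NamedSharding(mesh, PartitionSpec(None, None, "tp", None)))
kv = jax.device_put(kv, NamedSharding(mesh, PartitionSpec(None, None, None, None)))

# q and kv now carry specs that differ only on the head axis, which an
# identical-Q/K/V-spec assertion would reject even though the layout is valid.
print(q.sharding.spec, kv.sharding.spec)
```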

assert q_spec == k_spec == v_spec, (
f"Q, K, and V sharding specs must be identical but received {q_spec=}, {k_spec=},"
f" {v_spec=}"
)

impl = partial(FusedAttnFwdPrimitive.impl, config=config)
return mesh, impl, out_shardings, arg_shardings
