124 changes: 122 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
@@ -6,14 +6,25 @@
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs_for_attn,
generate_collectives_count,
)
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper
from test_fused_attn import (
FusedAttnRunner,
BiasShape,
SeqDescFormat,
_ScoreModSoftcap,
_has_cudnn_frontend_python,
_reference_attention,
_require_cudnn_frontend_score_mod,
)
from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.jax import autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
@@ -25,6 +36,7 @@
inverse_reorder_causal_load_balancing,
CPStrategy,
ReorderStrategy,
fused_attn,
)


@@ -272,6 +284,114 @@ def test_cross_attn(
runner.test_backward()


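# Each case is (batch, seqlen, num_heads, head_dim); the L0 level runs no cases here.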
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
"L2": [(4, 16, 4, 64)],
}


@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required")
class TestDistributedScoreModSelfAttn:
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SCORE_MOD_DATA_SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
def test_softcap_score_mod_with_aux_params_backward(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
):
_require_cudnn_frontend_score_mod()
batch, seqlen, num_heads, head_dim = data_shape
dp_axis = mesh_resource.dp_resource
tp_axis = mesh_resource.tpsp_resource

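        # Skip mesh configurations that cannot evenly shard the batch (DP) or heads (TP).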
if dp_axis is not None:
dp_size = mesh_shape[mesh_axes.index(dp_axis)]
if batch % dp_size != 0:
pytest.skip(f"{batch=} must be divisible by {dp_size=}")
if tp_axis is not None:
tp_size = mesh_shape[mesh_axes.index(tp_axis)]
if num_heads % tp_size != 0:
pytest.skip(f"{num_heads=} must be divisible by {tp_size=}")

devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
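        # BSHD layout: shard batch across DP and heads across TP; seqlen and head_dim stay replicated.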
qkv_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, tp_axis, None))

key = random.PRNGKey(2025)
q_key, k_key, v_key, dout_key = random.split(key, 4)
query = (0.125 * random.normal(q_key, data_shape, dtype=dtype)).astype(dtype)
key_tensor = (0.125 * random.normal(k_key, data_shape, dtype=dtype)).astype(dtype)
value = (0.125 * random.normal(v_key, data_shape, dtype=dtype)).astype(dtype)
doutput = random.normal(dout_key, data_shape, dtype=dtype)

scaling_factor = head_dim**-0.5
softcap = 0.8
softcap_score_mod = _ScoreModSoftcap()

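        # Contract the attention output with dout into a scalar so value_and_grad
        # yields dq/dk/dv in one pass; the output itself is returned as the aux value.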
def score_mod_loss(q, k, v, dout):
out = fused_attn(
(q, k, v),
None,
None,
None,
AttnBiasType.NO_BIAS,
AttnMaskType.NO_MASK,
QKVLayout.BSHD_BSHD_BSHD,
AttnSoftmaxType.VANILLA_SOFTMAX,
scaling_factor,
0.0,
True,
score_mod=softcap_score_mod.forward,
score_mod_bprop=softcap_score_mod.backward,
score_mod_tensors={"softcap": softcap},
score_mod_bprop_tensors={"softcap": softcap},
)
loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
return loss, out

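        # Unfused reference path that applies the same softcap to the attention scores.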
def ref_loss(q, k, v, dout):
out = _reference_attention(q, k, v, scaling_factor, softcap=softcap)
loss = jnp.sum(out.astype(jnp.float32) * dout.astype(jnp.float32))
return loss, out

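        # With has_aux=True, value_and_grad returns ((loss, out), (dq, dk, dv));
        # out_shardings mirrors that structure (the scalar loss is unsharded).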
jitted_score_mod = jax.jit(
jax.value_and_grad(score_mod_loss, argnums=(0, 1, 2), has_aux=True),
in_shardings=(
qkv_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
),
out_shardings=((None, qkv_sharding), (qkv_sharding, qkv_sharding, qkv_sharding)),
)
jitted_ref = jax.jit(jax.value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True))

sharded_args = (
jax.device_put(query, qkv_sharding),
jax.device_put(key_tensor, qkv_sharding),
jax.device_put(value, qkv_sharding),
jax.device_put(doutput, qkv_sharding),
)
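        # The fused path runs on sharded inputs under the mesh; the reference runs on unsharded inputs.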
with mesh, autocast(mesh_resource=mesh_resource):
(score_mod_value, score_mod_out), score_mod_grads = jitted_score_mod(*sharded_args)
(ref_value, ref_out), ref_grads = jitted_ref(query, key_tensor, value, doutput)

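        # The fused output and gradients must preserve the requested QKV sharding.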
assert score_mod_out.sharding == qkv_sharding
for grad in score_mod_grads:
assert grad.sharding == qkv_sharding

assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2)
assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2)
for grad, ref_grad in zip(score_mod_grads, ref_grads):
assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2)


DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),