-
Notifications
You must be signed in to change notification settings - Fork 277
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c20387c
commit 347f522
Showing
3 changed files
with
272 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from functools import partial | ||
import jax | ||
import jax.numpy as jnp | ||
from jax import custom_vjp | ||
|
||
lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1 | ||
|
||
@partial(custom_vjp, nondiff_argnums=(4, 5)) | ||
def flash_attention(query, key, value, bias, causal, softmax_scale): | ||
out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale) | ||
return out | ||
|
||
|
||
def _mha_forward(query, key, value, bias, causal, softmax_scale): | ||
# Get the batch size, sequence lengths, number of heads, and hidden dimension | ||
batch_size, q_seq_len, num_heads, d_model = query.shape | ||
|
||
# Transpose the query, key, and value tensors | ||
q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len] | ||
k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len] | ||
v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model] | ||
|
||
import neuronxcc.nki.language as nl | ||
from neuronxcc.nki.kernels.attention import flash_fwd | ||
seed = jnp.array([1]) | ||
|
||
# Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads | ||
if (num_heads % 2) == 0 and (num_heads // 2 > 0): | ||
grid = batch_size, nl.nc(lnc) * (num_heads // lnc) | ||
else: | ||
grid = batch_size, num_heads | ||
|
||
if bias != None: | ||
assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}" | ||
attn_output, lse = flash_fwd[grid]( | ||
q, | ||
k, | ||
v, | ||
seed, | ||
bias, | ||
use_causal_mask=causal, | ||
softmax_scale=softmax_scale, | ||
mixed_precision=True, | ||
dropout_p=0.0, | ||
) | ||
else: | ||
attn_output, lse = flash_fwd[grid]( | ||
q, | ||
k, | ||
v, | ||
seed, | ||
use_causal_mask=causal, | ||
softmax_scale=softmax_scale, | ||
mixed_precision=True, | ||
dropout_p=0.0, | ||
) | ||
# Transpose the output back to the original shape | ||
attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model] | ||
|
||
return attn_output, (lse, attn_output, q, k, v, bias) | ||
|
||
|
||
def _mha_backward(causal, softmax_scale, res, d_attn_output): | ||
lse, o, q, k, v, bias = res | ||
batch_size, num_heads, d_model, seq_len = q.shape | ||
|
||
# Transpose the input tensors | ||
o = o.transpose(0, 2, 3, 1) | ||
dy = d_attn_output.transpose(0, 2, 3, 1) | ||
|
||
# Transpose v tensor | ||
v = jnp.transpose(v, axes=(0, 1, 3, 2)) | ||
seed = jnp.array([1]) | ||
|
||
from neuronxcc.nki.kernels.attention import flash_attn_bwd | ||
import neuronxcc.nki.language as nl | ||
|
||
# Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads | ||
if (num_heads % 2) == 0 and (num_heads // 2 > 0): | ||
grid = batch_size, nl.nc(lnc) * (num_heads // lnc) | ||
else: | ||
grid = batch_size, num_heads | ||
|
||
if bias != None: | ||
assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}" | ||
d_query, d_key, d_value = flash_attn_bwd[grid]( | ||
q, | ||
k, | ||
v, | ||
o, | ||
dy, | ||
lse, | ||
seed, | ||
bias, | ||
use_causal_mask=causal, | ||
mixed_precision=True, | ||
dropout_p=0.0, | ||
softmax_scale=softmax_scale, | ||
) | ||
else: | ||
d_query, d_key, d_value = flash_attn_bwd[grid]( | ||
q, | ||
k, | ||
v, | ||
o, | ||
dy, | ||
lse, | ||
seed, | ||
use_causal_mask=causal, | ||
mixed_precision=True, | ||
dropout_p=0.0, | ||
softmax_scale=softmax_scale, | ||
) | ||
|
||
# Batch seq_len heads, head_dim | ||
# Transpose the gradients back to the original shape | ||
d_query = d_query.transpose(0, 3, 1, 2) | ||
d_key = d_key.transpose(0, 3, 1, 2) | ||
d_value = d_value.transpose(0, 3, 1, 2) | ||
|
||
return d_query, d_key, d_value, None | ||
|
||
|
||
flash_attention.defvjp(_mha_forward, _mha_backward) |
132 changes: 132 additions & 0 deletions
132
axlearn/common/flash_attention/neuron_attention_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# Copyright © 2024 Amazon Inc. | ||
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2.""" | ||
import functools | ||
|
||
import chex | ||
import jax | ||
import jax.numpy as jnp | ||
import pytest | ||
|
||
from axlearn.common.flash_attention.neuron_attention import flash_attention | ||
from axlearn.common.flash_attention.utils import mha_reference | ||
|
||
|
||
if jax.default_backend() != "neuron": | ||
pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"batch_size,seq_len,num_heads,per_head_dim", | ||
[ | ||
(1, 2048, 1, 64), | ||
(2, 2048, 2, 64), | ||
(1, 2048, 1, 128), | ||
(2, 2048, 2, 128), | ||
(1, 2048, 8, 128), | ||
(2, 2048, 8, 128), | ||
], | ||
) | ||
@pytest.mark.parametrize("use_fwd", [True, False]) | ||
@pytest.mark.parametrize("causal", [True, False]) | ||
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32]) | ||
def test_fwd_against_ref( | ||
batch_size: int, | ||
seq_len: int, | ||
num_heads: int, | ||
per_head_dim: int, | ||
use_fwd: bool, | ||
causal: bool, | ||
input_dtype: jnp.dtype, | ||
): | ||
sm_scale = 1.0 / (per_head_dim**0.5) | ||
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) | ||
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) | ||
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) | ||
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) | ||
|
||
bias = None | ||
segment_ids = None | ||
|
||
if use_fwd: | ||
|
||
@jax.jit | ||
def impl(q, k, v, bias): | ||
fn = functools.partial( | ||
flash_attention, | ||
causal=causal, | ||
softmax_scale=sm_scale, | ||
) | ||
out, _ = jax.vjp(fn, q, k, v, bias) | ||
return out | ||
|
||
else: | ||
impl = functools.partial( | ||
flash_attention, | ||
causal=causal, | ||
softmax_scale=sm_scale, | ||
) | ||
|
||
o = impl(q, k, v, bias) | ||
o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale) | ||
chex.assert_trees_all_close(o, o_ref, atol=0.05) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"batch_size,num_heads,seq_len,per_head_dim", | ||
[ | ||
(1, 1, 2048, 64), | ||
(2, 2, 2048, 64), | ||
(1, 1, 2048, 128), | ||
(2, 2, 2048, 128), | ||
(1, 8, 2048, 128), | ||
(2, 8, 2048, 128), | ||
], | ||
) | ||
@pytest.mark.parametrize("causal", [True, False]) | ||
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32]) | ||
def test_bwd_against_ref( | ||
batch_size: int, | ||
num_heads: int, | ||
seq_len: int, | ||
per_head_dim: int, | ||
causal: bool, | ||
input_dtype: jnp.dtype, | ||
): | ||
sm_scale = 1.0 / (per_head_dim**0.5) | ||
q = jax.random.normal( | ||
jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype | ||
) | ||
k = jax.random.normal( | ||
jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype | ||
) | ||
v = jax.random.normal( | ||
jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype | ||
) | ||
|
||
bias = None | ||
segment_ids = None | ||
|
||
def fn(q, k, v, bias): | ||
return flash_attention( | ||
q, | ||
k, | ||
v, | ||
bias, | ||
causal=causal, | ||
softmax_scale=sm_scale, | ||
).sum() | ||
|
||
def ref_fn(q, k, v, bias, segment_ids): | ||
return mha_reference( | ||
q, | ||
k, | ||
v, | ||
bias, | ||
segment_ids, | ||
causal=causal, | ||
softmax_scale=sm_scale, | ||
).sum() | ||
|
||
jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias) | ||
jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids) | ||
chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters