Skip to content

Commit

Permalink
Flash Attention for Neuron
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvtintin committed Dec 11, 2024
1 parent c20387c commit eab90f9
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 1 deletion.
129 changes: 129 additions & 0 deletions axlearn/common/flash_attention/neuron_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from absl import logging
from functools import partial
import jax
import jax.numpy as jnp
import jax.numpy as jnp
from jax import custom_vjp
import os

lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1

@partial(custom_vjp, nondiff_argnums=(4, 5))
def flash_attention(query, key, value, bias, causal, softmax_scale):
# NOTE : Merge with upstream. Old code supports both 2d and 4d bias but upstream code only supports 4d.
# We no longer need 2d logit_bias but should sync how we merge this check with upstream.
out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale)
return out


def _mha_forward(query, key, value, bias, causal, softmax_scale):
# Get the batch size, sequence lengths, number of heads, and hidden dimension
batch_size, q_seq_len, num_heads, d_model = query.shape

# Transpose the query, key, and value tensors
q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len]
k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len]
v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model]

import neuronxcc.nki.language as nl
from neuronxcc.nki.kernels.attention import flash_fwd
seed = jnp.array([1])

# Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads
if (num_heads % 2) == 0 and (num_heads // 2 > 0):
grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
else:
grid = batch_size, num_heads

if bias != None:
assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
attn_output, lse = flash_fwd[grid](
q,
k,
v,
seed,
bias,
use_causal_mask=causal,
softmax_scale=softmax_scale,
mixed_precision=True,
dropout_p=0.0,
)
else:
attn_output, lse = flash_fwd[grid](
q,
k,
v,
seed,
use_causal_mask=causal,
softmax_scale=softmax_scale,
mixed_precision=True,
dropout_p=0.0,
)
# Transpose the output back to the original shape
attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model]

return attn_output, (lse, attn_output, q, k, v, bias)


def _mha_backward(causal, softmax_scale, res, d_attn_output):
lse, o, q, k, v, bias = res
batch_size, num_heads, d_model, seq_len = q.shape

# Transpose the input tensors
o = o.transpose(0, 2, 3, 1)
dy = d_attn_output.transpose(0, 2, 3, 1)

# Transpose v tensor
v = jnp.transpose(v, axes=(0, 1, 3, 2))
seed = jnp.array([1])

from neuronxcc.nki.kernels.attention import flash_attn_bwd
import neuronxcc.nki.language as nl

# Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads
if (num_heads % 2) == 0 and (num_heads // 2 > 0):
grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
else:
grid = batch_size, num_heads

if bias != None:
assert bias.ndim == 4, f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
d_query, d_key, d_value = flash_attn_bwd[grid](
q,
k,
v,
o,
dy,
lse,
seed,
bias,
use_causal_mask=causal,
mixed_precision=True,
dropout_p=0.0,
softmax_scale=softmax_scale,
)
else:
d_query, d_key, d_value = flash_attn_bwd[grid](
q,
k,
v,
o,
dy,
lse,
seed,
use_causal_mask=causal,
mixed_precision=True,
dropout_p=0.0,
softmax_scale=softmax_scale,
)

# Batch seq_len heads, head_dim
# Transpose the gradients back to the original shape
d_query = d_query.transpose(0, 3, 1, 2)
d_key = d_key.transpose(0, 3, 1, 2)
d_value = d_value.transpose(0, 3, 1, 2)

return d_query, d_key, d_value, None


flash_attention.defvjp(_mha_forward, _mha_backward)
132 changes: 132 additions & 0 deletions axlearn/common/flash_attention/neuron_attention_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""
import functools

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.neuron_attention import flash_attention
from axlearn.common.flash_attention.utils import mha_reference


if jax.default_backend() != "neuron":
pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")


@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
[
(1, 2048, 1, 64),
(2, 2048, 2, 64),
(1, 2048, 1, 128),
(2, 2048, 2, 128),
(1, 2048, 8, 128),
(2, 2048, 8, 128),
],
)
@pytest.mark.parametrize("use_fwd", [True, False])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
batch_size: int,
seq_len: int,
num_heads: int,
per_head_dim: int,
use_fwd: bool,
causal: bool,
input_dtype: jnp.dtype,
):
sm_scale = 1.0 / (per_head_dim**0.5)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

bias = None
segment_ids = None

if use_fwd:

@jax.jit
def impl(q, k, v, bias):
fn = functools.partial(
flash_attention,
causal=causal,
softmax_scale=sm_scale,
)
out, _ = jax.vjp(fn, q, k, v, bias)
return out

else:
impl = functools.partial(
flash_attention,
causal=causal,
softmax_scale=sm_scale,
)

o = impl(q, k, v, bias)
o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale)
chex.assert_trees_all_close(o, o_ref, atol=0.05)


@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,per_head_dim",
[
(1, 1, 2048, 64),
(2, 2, 2048, 64),
(1, 1, 2048, 128),
(2, 2, 2048, 128),
(1, 8, 2048, 128),
(2, 8, 2048, 128),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
def test_bwd_against_ref(
batch_size: int,
num_heads: int,
seq_len: int,
per_head_dim: int,
causal: bool,
input_dtype: jnp.dtype,
):
sm_scale = 1.0 / (per_head_dim**0.5)
q = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
)
k = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
)
v = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype
)

bias = None
segment_ids = None

def fn(q, k, v, bias):
return flash_attention(
q,
k,
v,
bias,
causal=causal,
softmax_scale=sm_scale,
).sum()

def ref_fn(q, k, v, bias, segment_ids):
return mha_reference(
q,
k,
v,
bias,
segment_ids,
causal=causal,
softmax_scale=sm_scale,
).sum()

jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
17 changes: 16 additions & 1 deletion axlearn/common/flash_attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def mha_reference(


def flash_attention_implementation(
backend: Literal["cpu", "tpu", "gpu", "xla"],
backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"],
*,
mask: Optional[MaskFn] = None,
softmax_scale: float,
Expand Down Expand Up @@ -159,6 +159,21 @@ def jit_attn(query, key, value, bias, segment_ids):

return jit_attn

elif backend == "neuron":
from axlearn.common.flash_attention.neuron_attention import (
flash_attention as neuron_flash_attention,
)

# shard_map-decorated function needs to be jitted.
@jax.jit
def jit_attn(query, key, value, bias, segment_ids):
if segment_ids != None:
raise Exception("Sequence Packing is not supported on Neuron backend")
return neuron_flash_attention(
query, key, value, bias, causal, softmax_scale)

return jit_attn

elif backend in ("cpu", "xla"):
if backend == "cpu":
logging.warning("Flash attention CPU backend is for testing only.")
Expand Down

0 comments on commit eab90f9

Please sign in to comment.