Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements FlashDecoding with Sparsity Support #899

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
528 changes: 334 additions & 194 deletions axlearn/common/flash_attention/gpu_attention_benchmark.py

Large diffs are not rendered by default.

135 changes: 131 additions & 4 deletions axlearn/common/flash_attention/gpu_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

from axlearn.common.attention_bias import sliding_window_causal_mask
from axlearn.common.flash_attention.gpu_attention import (
cudnn_dot_product_attention,
flash_attention,
)
from axlearn.common.flash_attention.utils import mha_reference
from axlearn.common.flash_attention.gpu_decoding import NEG_INF, flash_decoding
from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference
from axlearn.common.test_utils import TestCase

if jax.default_backend() != "gpu":
pytest.skip(reason="Incompatible hardware", allow_module_level=True)
Expand Down Expand Up @@ -92,9 +96,132 @@ def impl(q, k, v, bias, segment_ids):
chex.assert_trees_all_close(o, o_ref, atol=0.07)


# We test the flash_attention against the reference mha_reference.
# The outputs should be close in both fp16 and fp32, with a relaxed bound due
# to the numerical difference during operations.
class FlashDecodingTest(TestCase):
"""Tests FlashDecoding."""

@parameterized.product(
[
dict(zip(["batch_size", "seq_len", "num_heads", "per_head_dim"], args))
for args in [
(1, 1024, 32, 64),
(1, 444, 16, 64),
(8, 1596, 48, 128),
(8, 4044, 64, 128),
]
],
softmax_scale=[1.0, 0.83],
attention_bias_type=["2d", "4d", None],
input_dtype=[jnp.float32, jnp.float16],
padding=[0, 111],
kv_head_factor=[1, 4, 8],
)
def test_decode_against_ref(
self,
batch_size: int,
seq_len: int,
num_heads: int,
per_head_dim: int,
softmax_scale: float,
attention_bias_type: Literal["2d", "4d", None],
input_dtype: jnp.dtype,
padding: int,
kv_head_factor: int,
):
self.assertEqual(num_heads % kv_head_factor, 0)
assert num_heads % kv_head_factor == 0
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
q = jax.random.normal(k1, (batch_size, 1, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(
k2,
(batch_size, seq_len + padding, num_heads // kv_head_factor, per_head_dim),
dtype=input_dtype,
)
v = jax.random.normal(
k3,
(batch_size, seq_len + padding, num_heads // kv_head_factor, per_head_dim),
dtype=input_dtype,
)

if attention_bias_type == "4d":
bias = jax.random.normal(
k4, (batch_size, num_heads, 1, seq_len + padding), dtype=input_dtype
)
elif attention_bias_type == "2d":
bias = jax.random.normal(k4, (1, 1, 1, seq_len + padding), dtype=input_dtype)
else:
bias = None

impl = functools.partial(flash_decoding, softmax_scale=softmax_scale, kv_seq_len=seq_len)

o = impl(q, k, v, bias)
if bias is not None:
bias = bias[:, :, :, :seq_len]
o_ref = mha_reference(
q,
_repeat_kv_heads(num_heads, k[:, :seq_len]),
_repeat_kv_heads(num_heads, v[:, :seq_len]),
bias,
None,
causal=False,
softmax_scale=softmax_scale,
)
self.assertGreaterEqual(jnp.median(jnp.abs(o_ref)).item(), 0.25)
if input_dtype is jnp.float32:
self.assertNestedAllClose(o, o_ref, rtol=0.01, atol=0.01)
else:
self.assertNestedAllClose(o, o_ref, rtol=0.05, atol=0.05)

@parameterized.product(
[
dict(zip(["batch_size", "seq_len", "num_heads", "per_head_dim"], args))
for args in [
(1, 1024 * 16, 8, 128),
(1, 1128, 32, 64),
(8, 305, 48, 128),
(8, 4042, 64, 128),
]
],
input_dtype=[jnp.float32, jnp.float16],
padding=[0, 123],
window_len=[16, 127],
)
def test_decode_sliding_window(
self,
batch_size: int,
seq_len: int,
num_heads: int,
per_head_dim: int,
input_dtype: jnp.dtype,
padding: int,
window_len: int,
):
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(42), 3)
q = jax.random.normal(k1, (batch_size, 1, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(
k2, (batch_size, seq_len + padding, num_heads, per_head_dim), dtype=input_dtype
)
v = jax.random.normal(
k3, (batch_size, seq_len + padding, num_heads, per_head_dim), dtype=input_dtype
)

o = flash_decoding(
q,
k,
v,
mask_fn=sliding_window_causal_mask(window_len),
kv_seq_len=seq_len,
)
mask = jnp.zeros((1, 1, 1, seq_len), dtype=input_dtype)
mask = mask.at[:, :, :, : -window_len - 1].set(NEG_INF)
o_ref = mha_reference(q, k[:, :seq_len], v[:, :seq_len], mask, None, causal=False)

self.assertGreaterEqual(jnp.median(jnp.abs(o_ref)).item(), 0.25)
if input_dtype is jnp.float32:
self.assertNestedAllClose(o, o_ref, atol=0.01)
else:
self.assertNestedAllClose(o, o_ref, atol=0.025)


@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,per_head_dim",
[
Expand Down
Loading
Loading