Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize TPU Flash Attention (400x speed-up on 32k long context) #845

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions axlearn/common/flash_attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
make_segment_mask,
)
from axlearn.common.config import config_class
from axlearn.common.flash_attention import tpu_attention
from axlearn.common.flash_attention.utils import (
MultiHeadAttentionImpl,
flash_attention_implementation,
Expand Down Expand Up @@ -169,10 +170,6 @@ def _compute_attention(
cfg = self.config
backend = self._backend()

# Repeats key/value heads dim if necessary.
k_proj = self._repeat_kv_heads(k_proj)
v_proj = self._repeat_kv_heads(v_proj)

batch, target_len, num_heads, _ = q_proj.shape
_, source_len, _, _ = k_proj.shape

Expand Down Expand Up @@ -228,7 +225,18 @@ def _compute_attention(
f"{k_proj.shape[1]} for correctly supported GPU flash attention usage."
)

if backend == "tpu":
if backend == "cpu" and not tpu_attention.check_tpu_splash_attention(
query=q_proj,
key=k_proj,
has_mask=bool(cfg.mask),
segment_ids=segment_ids,
has_bias=(attention_logit_biases is not None),
):
backend = "xla"

if backend in ("tpu", "cpu"):
# Splash attention needs to know sliding_window_size.
mask_fn = cfg.mask
assert q_proj.shape[1] % cfg.tpu_block_size == 0, (
f"Target seq len {q_proj.shape[1]} must be "
f"divisible by block size {cfg.tpu_block_size}."
Expand Down Expand Up @@ -263,6 +271,12 @@ def _compute_attention(
q_proj = self.scale_query(q_proj)
k_proj = self.scale_key(k_proj)

# TODO(dhwang2): splash attention supports GQA natively, so don't repeat with proper shard.
# https://github.com/jax-ml/jax/blob/7b9914d711593dca8725d46aa1dadb2194284519/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L934
# Repeats key/value heads dim if necessary.
k_proj = self._repeat_kv_heads(k_proj)
v_proj = self._repeat_kv_heads(v_proj)

# Constrain input to conform to partitioned MHA expectations.
q_proj = with_sharding_constraint(q_proj, cfg.mha_dim_to_partition_spec["btnh"])
k_proj = with_sharding_constraint(k_proj, cfg.mha_dim_to_partition_spec["bsnh"])
Expand Down
10 changes: 9 additions & 1 deletion axlearn/common/flash_attention/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized
from absl.testing import absltest, parameterized
from jax.experimental import mesh_utils
from jax.sharding import Mesh

Expand Down Expand Up @@ -91,6 +91,7 @@ def _prepare_layers(
sliding_window_size,
inference=False,
set_layer_bias_recursively=False,
tpu_block_size=512,
):
hidden_dim = num_heads * per_head_dim
kwargs = dict(
Expand All @@ -110,6 +111,7 @@ def _prepare_layers(
.set(
mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names),
output_dim_to_partition_spec=default_output_dim_to_partition_spec(mesh_axis_names),
tpu_block_size=tpu_block_size,
)
)
if inference:
Expand Down Expand Up @@ -378,7 +380,9 @@ def test_forward(
mesh_axis_names=mesh_axis_names,
causal=causal,
sliding_window_size=sliding_window_size,
tpu_block_size=128,
)

# pylint: disable-next=protected-access
if test_layer._backend() == "gpu" and query_len_multiplier != 1:
pytest.skip(
Expand Down Expand Up @@ -734,3 +738,7 @@ def test_extend_step(
atol=2e-2,
)
jax.clear_backends()


if __name__ == "__main__":
absltest.main()
Loading
Loading