[WIP] Add support for flex attention (paged attention) #35419

Draft
wants to merge 15 commits into base: main
83 changes: 83 additions & 0 deletions src/transformers/cache_utils.py
@@ -2120,3 +2120,86 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None:

self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)


class PagedAttentionCache(Cache):
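    # Experimental KV cache built on torch's paged-attention prototype: every layer owns a flat
    # page pool (`key_cache[i]` / `value_cache[i]`) plus a `PagedAttention` object whose page table
    # maps logical token positions to physical pages; a logical block mask is converted once into
    # its paged equivalent and stored as `self.block_mask` for flex attention to consume.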
def __init__(
self, config, batch_size, max_cache_len, device, dtype, layer_device_map, n_pages=None, page_size=128
):
super().__init__()
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self.paged_attentions = []
self.key_cache = []
self.value_cache = []
self.page_size = page_size
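        # If not given explicitly, the page budget defaults to enough whole pages to hold
        # `max_cache_len` tokens for every sequence in the batch,
        # i.e. ceil(max_cache_len / page_size) * batch_size.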
if n_pages is not None:
self.n_pages = n_pages
else:
self.n_pages = (max_cache_len + page_size - 1) // page_size * batch_size
KV_H = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
QK_D = config.hidden_size // config.num_attention_heads
V_D = QK_D
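        # Imported lazily: both modules are experimental / private PyTorch APIs and may change.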
from torch.nn.attention.experimental._paged_attention import PagedAttention
from torch.nn.attention.flex_attention import create_block_mask, noop_mask

for i in range(config.num_hidden_layers):
max_cached_seq_len = self.n_pages * self.page_size
self.paged_attentions.append(PagedAttention(self.n_pages, self.page_size, batch_size, device=device))
self.key_cache.append(torch.zeros(1, KV_H, max_cached_seq_len, QK_D, device=device, dtype=dtype))
self.value_cache.append(torch.zeros(1, KV_H, max_cached_seq_len, V_D, device=device, dtype=dtype))
self.batch_reserve(self.paged_attentions[i], torch.tensor([max_cache_len for _ in range(batch_size)]))
self.batch_size = batch_size
self.max_cache_len = max_cache_len
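        # Build a dense (no-op) block mask over the logical KV space once, then convert it to the
        # physical page layout so flex-attention kernels can consume it directly.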
block_mask = create_block_mask(noop_mask, batch_size, 1, 1, max_cache_len, device=device, BLOCK_SIZE=page_size)
self.block_mask = self.paged_attentions[0].convert_logical_block_mask(block_mask)
        self.score_mods = [None, None]

def reset(self) -> None:
"""Resets the cache values while preserving the objects."""

self._seen_tokens = 0

# Zero out cache.
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address.
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

def batch_reserve(self, paged_attention, target_seq_len):
(B,) = target_seq_len.shape
for b in range(B):
paged_attention.reserve(
torch.tensor(b),
target_seq_len[b],
)

def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
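        # Writes the incoming key/value states into their physical pages (via `PagedAttention.assign`)
        # and returns the full per-layer page pools; `cache_kwargs["cache_position"]` supplies the
        # logical positions being written.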
# update seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
KV_B, KV_H, KV_S, QK_D = key_states.shape
device = key_states.device
batch_idx = torch.arange(KV_B, device=device, dtype=torch.int32)
self.paged_attentions[layer_idx].assign(
batch_idx,
cache_kwargs["cache_position"],
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
)
return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

def reorder_cache(self, beam_idx):
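        # Beam search support: permute each layer's page table rows according to `beam_idx` so the
        # physical pages follow their beams without copying the cached keys/values.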
for layer_idx in range(len(self.paged_attentions)):
if self.key_cache[layer_idx] != []:
page_table = self.paged_attentions[layer_idx].page_table.clone()
for batch_idx, target_batch_idx in enumerate(beam_idx.tolist()):
page_table[batch_idx] = self.paged_attentions[layer_idx].page_table[target_batch_idx]
self.paged_attentions[layer_idx].page_table = page_table
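For reference, a minimal construction sketch of the new cache. It is untested and only exercises `__init__` as written in this diff; the config values, device fallback, and printed shapes are illustrative assumptions, and it requires a PyTorch build that ships the experimental `torch.nn.attention.experimental._paged_attention` module.

import torch
from transformers import LlamaConfig
from transformers.cache_utils import PagedAttentionCache

# Toy geometry: 2 sequences of up to 256 tokens, 64-token pages -> 8 physical pages per layer.
config = LlamaConfig(num_hidden_layers=2, hidden_size=256, num_attention_heads=8, num_key_value_heads=8)
device = "cuda" if torch.cuda.is_available() else "cpu"

cache = PagedAttentionCache(
    config,
    batch_size=2,
    max_cache_len=256,
    device=device,
    dtype=torch.float16,
    layer_device_map=None,
    page_size=64,
)
print(cache.n_pages)             # 8 = ceil(256 / 64) * batch_size
print(cache.key_cache[0].shape)  # (1, 8, 512, 32): one flat page pool per layer
# During decoding, each attention layer would call
#   cache.update(key_states, value_states, layer_idx, {"cache_position": cache_position})
# which routes the new tokens into physical pages through PagedAttention.assign.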
2 changes: 2 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -54,6 +54,7 @@
HybridCache,
MambaCache,
OffloadedStaticCache,
PagedAttentionCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
@@ -70,6 +71,7 @@
"sliding_window": SlidingWindowCache,
"hybrid": HybridCache,
"mamba": MambaCache,
"paged": PagedAttentionCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys())
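With this mapping entry, `cache_implementation="paged"` becomes selectable like the existing `"static"` / `"sliding_window"` values. A hedged one-liner to illustrate (it assumes nothing beyond the mapping above):

from transformers import GenerationConfig

# "paged" now resolves to PagedAttentionCache when the cache is set up for generation.
gen_config = GenerationConfig(cache_implementation="paged", max_new_tokens=32)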
18 changes: 12 additions & 6 deletions src/transformers/generation/utils.py
@@ -15,6 +15,7 @@
# limitations under the License.
import copy
import inspect
import time
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -1650,6 +1651,11 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):
"dtype": cache_dtype,
"layer_device_map": layer_device_map,
}
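        # Optional paged-attention geometry: `n_pages` / `page_size` set on the model config are
        # forwarded to PagedAttentionCache.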
if cache_implementation == "paged":
if hasattr(self.config, "n_pages"):
cache_kwargs["n_pages"] = self.config.n_pages
if hasattr(self.config, "page_size"):
cache_kwargs["page_size"] = self.config.page_size
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()
@@ -2228,7 +2234,7 @@ def generate(
)

# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
result = self._sample(
result, latency_list = self._sample(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
@@ -2397,7 +2403,7 @@ def typeerror():
should_convert_cache = True
if should_convert_cache:
result.past_key_values = result.past_key_values.to_legacy_cache()
return result
return result, latency_list

def _has_unfinished_sequences(
self,
@@ -3178,6 +3184,7 @@ def _sample(
`model.config.is_encoder_decoder=True`.
"""
# init values
latency_list = []
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
@@ -3211,16 +3218,14 @@
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
tic = time.time()
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

# forward pass to get next token
outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
@@ -3282,6 +3287,7 @@
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
del outputs
latency_list.append(time.time() - tic)

if streamer is not None:
streamer.end()
@@ -3309,7 +3315,7 @@
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
return input_ids, latency_list

def _temporary_reorder_cache(self, past_key_values, beam_idx):
"""
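Putting the pieces together, a reviewer could exercise the new path roughly as below. The model id and the per-model opt-in flags are assumptions (no model is wired up in this diff), and note that in this revision the sampled path makes `generate` return a `(sequences, latency_list)` tuple:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder; any model class that opts into paged/flex attention
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="paged_attention",  # newly accepted in _autoset_attn_implementation below
).to("cuda")

inputs = tokenizer("Paged attention stores the KV cache in", return_tensors="pt").to(model.device)
# `cache_implementation="paged"` makes `_get_cache` build a PagedAttentionCache; `n_pages` and
# `page_size` are picked up from the model config if present.
sequences, latency_list = model.generate(**inputs, max_new_tokens=32, cache_implementation="paged")
print(tokenizer.decode(sequences[0], skip_special_tokens=True))
print(f"first step: {latency_list[0]:.3f}s, "
      f"avg decode step: {sum(latency_list[1:]) / max(len(latency_list) - 1, 1):.3f}s")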
8 changes: 8 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1505,12 -1505,18 @@ def _autoset_attn_implementation(
"eager",
"sdpa",
"flash_attention_2",
"flex_attention",
"paged_attention",
]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
                if cls._supports_flex_attn:
                    message += ', `"attn_implementation=flex_attention"` (implementation using flex attention)'
                if cls._supports_paged_attn:
                    message += ', `"attn_implementation=paged_attention"` (implementation using paged attention)'
raise ValueError(message + ".")

# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
@@ -1561,6 +1567,8 @@ def _autoset_attn_implementation(
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
elif requested_attn_implementation in ["flex_attention", "paged_attention"]:
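            # Flex/paged attention: keep the user-requested implementation as-is; the SDPA/FA2-specific
            # checks below are skipped for these backends.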
return config
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
else: