ModernBERT FlexAttention #35423

Open
wants to merge 5 commits into base: main
110 changes: 103 additions & 7 deletions src/transformers/models/modernbert/modeling_modernbert.py
@@ -38,7 +38,7 @@
is_flash_attn_2_available,
logging,
)
from ...utils.import_utils import is_triton_available
from ...utils.import_utils import is_torch_flex_attn_available, is_triton_available
from .configuration_modernbert import ModernBertConfig


@@ -49,6 +49,12 @@
else:
RotaryEmbedding = object

# NOTE: the ModernBERT FlexAttention implementation is not compatible with torch < 2.6
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
else:
BlockMask, create_block_mask, flex_attention = object, object, object

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base"
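For context, a minimal usage sketch of the new backend (the checkpoint name comes from `_CHECKPOINT_FOR_DOC`; everything else in the snippet is assumed rather than taken from this diff, and torch >= 2.6 is needed per the NOTE above):

```python
# Hedged usage sketch: selects the new FlexAttention backend via attn_implementation.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = AutoModelForMaskedLM.from_pretrained(
    "answerdotai/ModernBERT-base",
    attn_implementation="flex_attention",  # requires torch >= 2.6 per the guard above
)

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
```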
@@ -411,10 +417,44 @@ def sdpa_attention_forward(
return (attn_output,)


def flex_attention_forward(
module: "ModernBertAttention",
qkv: torch.Tensor,
rotary_emb: ModernBertUnpaddedRotaryEmbedding,
position_ids: Optional[torch.LongTensor],
block_mask: "BlockMask",
max_seqlen: int,
bs: int,
dim: int,
**_kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# add dummy batch dimension -> [batch_size=1, total_nnz, 3, nheads, headdim]
qkv = qkv.unsqueeze(0)
cos, sin = rotary_emb(qkv, position_ids=position_ids)
query, key, value = qkv.transpose(3, 1).unbind(dim=2)
# query, key, value: [batch_size, nheads, total_nnz, head_dim]
query, key = apply_rotary_pos_emb(query, key, cos, sin)

attn_output = flex_attention(
query,
key,
value,
score_mod=None,
block_mask=block_mask,
enable_gqa=False,
scale=None,
return_lse=False,
)

attn_output = attn_output.squeeze(0).transpose(0, 1).contiguous()
return (attn_output.view(bs, dim),)
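To make the reshaping in `flex_attention_forward` easier to follow, here is a standalone shape walk-through with toy dimensions (values assumed for illustration only):

```python
# Toy shape check of the unpadded QKV path used above.
import torch

total_nnz, nheads, headdim = 10, 4, 8
qkv = torch.randn(total_nnz, 3, nheads, headdim)        # packed tokens, no batch dimension
qkv = qkv.unsqueeze(0)                                   # [1, total_nnz, 3, nheads, headdim]
query, key, value = qkv.transpose(3, 1).unbind(dim=2)    # each [1, nheads, total_nnz, headdim]
assert query.shape == (1, nheads, total_nnz, headdim)
```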


MODERNBERT_ATTENTION_FUNCTION = {
"flash_attention_2": flash_attention_forward,
"eager": eager_attention_forward,
"sdpa": sdpa_attention_forward,
"flex_attention": flex_attention_forward,
}


@@ -479,7 +519,7 @@ def forward(
qkv = self.Wqkv(hidden_states)

bs = hidden_states.shape[0]
if self.config._attn_implementation == "flash_attention_2":
if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
else:
qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
Expand Down Expand Up @@ -523,6 +563,7 @@ def forward(
sliding_window_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
block_mask: Optional["BlockMask"] = None,
max_seqlen: Optional[int] = None,
output_attentions: Optional[bool] = False,
) -> torch.Tensor:
@@ -532,6 +573,7 @@
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
block_mask=block_mask,
max_seqlen=max_seqlen,
output_attentions=output_attentions,
)
Expand Down Expand Up @@ -574,7 +616,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
_no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = False
_supports_flex_attn = True

def _init_weights(self, module: nn.Module):
cutoff_factor = self.config.initializer_cutoff_factor
@@ -837,6 +879,40 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embeddings.tok_embeddings = value

@classmethod
def offsets_to_sequence_ids_tensor(cls, offsets):
counts = offsets[1:] - offsets[:-1]
return torch.repeat_interleave(torch.arange(len(counts), device=offsets.device, dtype=torch.int32), counts)
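A quick worked example of what this helper produces (illustrative offsets for three packed sequences of lengths 3, 2 and 4):

```python
# Illustrative values only; mirrors offsets_to_sequence_ids_tensor above.
import torch

offsets = torch.tensor([0, 3, 5, 9])
counts = offsets[1:] - offsets[:-1]                      # tensor([3, 2, 4])
sequence_ids = torch.repeat_interleave(
    torch.arange(len(counts), dtype=torch.int32), counts
)
print(sequence_ids)  # tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.int32)
```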

def create_attention_mask(self, sequence_ids, cu_seqlens, window_size):
"""
Creates a block mask combining sequence masking with local or global attention masking.
"""

def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx):
# only allow attention within the same sequence
same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx]

# get position within the sequence
q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]]
kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]]

# sliding window within each sequence
in_window = (q_pos - kv_pos).abs() <= window_size

return same_seq & in_window

total_nnz = cu_seqlens[-1]
block_mask = create_block_mask(
sliding_window_seq_mask_mod,
B=None,
H=None,
Q_LEN=total_nnz,
KV_LEN=total_nnz,
device=sequence_ids.device,
)
return block_mask
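As a sanity check of the mask semantics, the `mask_mod` logic can be exercised on its own with toy values (everything below is assumed, not part of the diff):

```python
# Toy check: cross-sequence attention is blocked, in-window same-sequence attention is allowed.
import torch

cu_seqlens = torch.tensor([0, 3, 7])          # two packed sequences of lengths 3 and 4
sequence_ids = torch.tensor([0, 0, 0, 1, 1, 1, 1])
window_size = 1

def mask_mod(b, h, q_idx, kv_idx):
    same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx]
    q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]]
    kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]]
    return same_seq & ((q_pos - kv_pos).abs() <= window_size)

print(mask_mod(0, 0, torch.tensor(2), torch.tensor(3)))  # tensor(False): different sequences
print(mask_mod(0, 0, torch.tensor(4), torch.tensor(5)))  # tensor(True): same sequence, within window
```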

@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
@@ -877,13 +953,15 @@ def forward(
attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)

repad = False
if self.config._attn_implementation == "flash_attention_2":
if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
if indices is None and cu_seqlens is None and max_seqlen is None:
repad = True
with torch.no_grad():
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
inputs=input_ids, attention_mask=attention_mask
)
if self.config._attn_implementation == "flex_attention":
position_ids = torch.arange(cu_seqlens[-1], device=cu_seqlens.device).unsqueeze(0)
else:
if position_ids is None:
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
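For readers unfamiliar with the unpadded ("varlen") layout shared by the flash_attention_2 and flex_attention paths, here is a rough sketch of the quantities `_unpad_modernbert_input` returns; this is an illustrative reimplementation of the idea, not the actual helper:

```python
# Illustrative sketch of unpadding; not the library's _unpad_modernbert_input.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)  # batch of 2, padded to 4
seqlens = attention_mask.sum(dim=-1)                          # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten()).flatten()   # flat positions of real tokens
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))                 # tensor([0, 3, 5])
max_seqlen = int(seqlens.max())
# input_ids of shape [2, 4] would then be packed to [total_nnz=5] via input_ids.flatten()[indices]
```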
@@ -894,7 +972,23 @@ def forward(

hidden_states = self.embeddings(input_ids)

for encoder_layer in self.layers:
# create block mask
if self.config._attn_implementation == "flex_attention":
sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens)
local_window_size = self.config.local_attention // 2
_cached_local_mask = self.create_attention_mask(sequence_ids, cu_seqlens, local_window_size)
global_window_size = max_seqlen
_cached_global_mask = self.create_attention_mask(sequence_ids, cu_seqlens, global_window_size)
else:
block_mask = None

for layer_id, encoder_layer in enumerate(self.layers):
if self.config._attn_implementation == "flex_attention":
if layer_id % self.config.global_attn_every_n_layers == 0:
block_mask = _cached_global_mask
else:
block_mask = _cached_local_mask

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

@@ -906,6 +1000,7 @@
sliding_window_mask,
position_ids,
cu_seqlens,
block_mask,
max_seqlen,
output_attentions,
)
@@ -916,6 +1011,7 @@ def forward(
sliding_window_mask=sliding_window_mask,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
block_mask=block_mask,
max_seqlen=max_seqlen,
output_attentions=output_attentions,
)
@@ -1046,7 +1142,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self._maybe_set_compile()

if self.config._attn_implementation == "flash_attention_2":
if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
if indices is None and cu_seqlens is None and max_seqlen is None:
batch_size, seq_len = input_ids.shape[:2]
if attention_mask is None:
@@ -1092,7 +1188,7 @@ def forward(
if labels is not None:
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

if self.config._attn_implementation == "flash_attention_2":
if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
with torch.no_grad():
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
if not return_dict: