
Commit

Fix Gemma2 4d attention mask (#31674)
Update modeling_gemma2.py

Co-authored-by: Arthur <[email protected]>
hiyouga and ArthurZucker committed Jun 28, 2024
1 parent 7edc993 commit 8691867
Showing 1 changed file with 6 additions and 4 deletions.

src/transformers/models/gemma2/modeling_gemma2.py (10 changes: 6 additions & 4 deletions)
@@ -629,11 +629,13 @@ def forward(
         if (
             self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
         ):  # efficient SDPA and no padding
-            attention_mask = attention_mask * torch.tril(
-                torch.ones_like(attention_mask), diagonal=-self.sliding_window
+            min_dtype = torch.finfo(hidden_states.dtype).min
+            sliding_window_mask = torch.tril(
+                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
             )
-            if attention_mask.shape[1] <= 1:  # when decoding
-                attention_mask = attention_mask[:, -self.sliding_window :]
+            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
+            if attention_mask.shape[-1] <= 1:  # when decoding
+                attention_mask = attention_mask[:, :, :, -self.sliding_window :]
 
         residual = hidden_states

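To illustrate what the patched branch computes, here is a minimal standalone sketch (my own illustration, not code from the repository; the batch/head/sequence sizes and the sliding_window value are made-up assumptions). It mirrors the added lines: build a boolean lower-triangular mask for key positions at least sliding_window tokens behind the query, write dtype-min into the 4D additive attention mask with torch.where, and apply the decoding-time slice on the last dimension.

import torch

# Illustrative shapes; in the model the 4D additive mask comes from the
# regular mask-preparation code (0.0 = attend, dtype-min = blocked).
batch, heads, q_len, kv_len = 1, 1, 6, 6
sliding_window = 3

attention_mask = torch.zeros(batch, heads, q_len, kv_len, dtype=torch.float32)
min_dtype = torch.finfo(attention_mask.dtype).min

# True where the key position is at least `sliding_window` tokens behind the
# query position; torch.tril operates on the last two dimensions.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)

# Block those positions by writing dtype-min, leaving the rest of the 4D mask
# untouched (instead of multiplying the float mask as the removed lines did).
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

# Decoding branch mirrored from the patch: when the last mask dimension has
# length 1, keep only the last `sliding_window` key positions of the 4D mask.
if attention_mask.shape[-1] <= 1:
    attention_mask = attention_mask[:, :, :, -sliding_window:]

print(attention_mask[0, 0])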
