FlexAttention slower than eager in HF transformers #95
When running FlexAttention vs. SDPA alone (with compile), I get:
```python
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
from tabulate import tabulate
from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    create_mask,
)

torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)

print(f"Torch version: {torch.__version__}")

from torch._inductor.utils import do_bench_using_profiling
from typing import Callable


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


benchmark_fn = benchmark_cuda_function_in_microseconds

WINDOW_SIZE = 64


def generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE):
    def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx):
        # only allow attention within the same sequence
        same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx]
        # get position within the sequence
        q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]]
        kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]]
        # sliding window within each sequence
        in_window = (q_pos - kv_pos).abs() <= WINDOW_SIZE
        return same_seq & in_window

    return sliding_window_seq_mask_mod


def SWA_mask(b, h, q_idx, kv_idx):
    # sliding window within each sequence
    in_window = (q_idx - kv_idx).abs() <= WINDOW_SIZE
    return in_window
# Benchmarking function
def run_benchmark(batch_sizes, sequence_lengths, num_heads=16, hidden_dim=64, n_runs=3):
    results = []
    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            q = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")
            k = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")
            v = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")

            # per-document lengths for the packed layout; renamed so the outer
            # `sequence_lengths` argument is not overwritten inside the loop
            doc_lengths = [seq_len] * batch_size
            sequence_ids = torch.cat(
                [torch.full((length,), i, dtype=torch.long) for i, length in enumerate(doc_lengths)]
            ).to("cuda")
            _, counts = torch.unique_consecutive(sequence_ids, return_counts=True)
            cu_seqlens = torch.cat([torch.tensor([0], device=sequence_ids.device), counts.cumsum(0)])

            block_mask = create_block_mask(
                generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE),
                B=None,
                H=None,
                Q_LEN=cu_seqlens[-1],
                KV_LEN=cu_seqlens[-1],
                device="cuda",
            )
            mask = create_mask(SWA_mask, None, None, seq_len, seq_len, device="cuda")

            # Benchmark flex_attention on the packed (1, H, B*S, D) layout
            flex_times = []
            for _ in range(n_runs):
                flex_time = benchmark_fn(
                    flex_attention,
                    q.reshape(1, num_heads, -1, hidden_dim),
                    k.reshape(1, num_heads, -1, hidden_dim),
                    v.reshape(1, num_heads, -1, hidden_dim),
                    score_mod=None,
                    block_mask=block_mask,
                )
                flex_times.append(flex_time)
            flex_avg_time = (sum(flex_times) / n_runs) / 1000  # microseconds -> ms

            # Benchmark scaled_dot_product_attention with mask
            sdpa_times = []
            with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                for _ in range(n_runs):
                    sdpa_time = benchmark_fn(
                        scaled_dot_product_attention,
                        q,
                        k,
                        v,
                        attn_mask=mask,
                    )
                    sdpa_times.append(sdpa_time)
            sdpa_avg_time = (sum(sdpa_times) / n_runs) / 1000  # microseconds -> ms

            results.append(
                {
                    "Batch Size": batch_size,
                    "Seq Length": seq_len,
                    "FLEX Avg Time (ms)": f"{flex_avg_time:.2f}",
                    "SDPA Avg Time (ms)": f"{sdpa_avg_time:.2f}",
                }
            )
    return results
if __name__ == "__main__":
    batch_sizes = [1, 2, 4]
    sequence_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
    n_runs = 5
    results = run_benchmark(batch_sizes, sequence_lengths, n_runs=n_runs)

    # Generate table
    print("\n=== Benchmark Results ===")
    print(tabulate(results, headers="keys", tablefmt="grid"))
```

So my question is: how can `torch._dynamo.config.cache_size_limit = 1000` and `flex_attention = torch.compile(flex_attention, dynamic=False)` be cleanly integrated into transformers?
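For reference, one shape such an integration could take (a minimal sketch, not the actual transformers code; `get_compiled_flex_attention` and `flex_attention_forward` are hypothetical names) is to compile `flex_attention` once at module scope and have every attention layer reuse the cached compiled callable:

```python
# Minimal sketch: compile flex_attention once and reuse it across all layers.
from functools import lru_cache

import torch
from torch.nn.attention.flex_attention import flex_attention


@lru_cache(maxsize=1)
def get_compiled_flex_attention():
    # a large dynamo cache limit avoids silently falling back to eager
    # when many distinct shapes / masks trigger recompiles
    torch._dynamo.config.cache_size_limit = 1000
    return torch.compile(flex_attention, dynamic=False)


def flex_attention_forward(query, key, value, block_mask=None):
    # every layer calls the same cached, compiled kernel
    compiled = get_compiled_flex_attention()
    return compiled(query, key, value, block_mask=block_mask)
```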
Without a doubt it will be slower than eager when it is not compiled. Let me ping some HF folks to see if we can raise a warning / ensure it is easy to compile.
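In the meantime, a user-side workaround is to compile the model's forward pass so the FlexAttention call is captured and lowered by Inductor. A sketch, assuming a recent transformers release that accepts `attn_implementation="flex_attention"` (the checkpoint name below is only a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flex_attention",  # assumes this option exists in your transformers version
).to("cuda")

# Compiling the forward pass is what makes FlexAttention fast;
# uncompiled FlexAttention is expected to be slower than eager SDPA.
model.forward = torch.compile(model.forward, dynamic=False)

inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model(**inputs)
```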
Related PR: huggingface/transformers#35423
Repro gist: https://gist.github.com/staghado/c3688a51aadec9e0b63316d8a7227064
The implementation combines a sliding window mask with a document mask. The masks are created once for each input and re-used for subsequent layers.
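For context, that combination can be written roughly as the logical AND of a document mask and a sliding-window mask (a sketch assuming a PyTorch version that ships `and_masks`; `document_ids` and `TOTAL_LEN` are hypothetical):

```python
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

WINDOW_SIZE = 64
TOTAL_LEN = 4096  # total packed sequence length (hypothetical)
# hypothetical per-token document ids for a packed batch
document_ids = torch.zeros(TOTAL_LEN, dtype=torch.long, device="cuda")


def document_mask(b, h, q_idx, kv_idx):
    # tokens may only attend within their own document
    return document_ids[q_idx] == document_ids[kv_idx]


def sliding_window_mask(b, h, q_idx, kv_idx):
    # tokens may only attend within a +/- WINDOW_SIZE window
    return (q_idx - kv_idx).abs() <= WINDOW_SIZE


# build the BlockMask once per input and reuse it for every layer
combined_mask = and_masks(document_mask, sliding_window_mask)
block_mask = create_block_mask(
    combined_mask, B=None, H=None, Q_LEN=TOTAL_LEN, KV_LEN=TOTAL_LEN, device="cuda"
)
```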
One thing that might be the issue is that the flex_attention function is not compiled in transformers.
I might be missing something; thanks in advance for your help.