FlexAttention slower than eager in HF transformers #95

Open
staghado opened this issue Dec 27, 2024 · 2 comments

@staghado

Related PR: huggingface/transformers#35423
Repro gist: https://gist.github.com/staghado/c3688a51aadec9e0b63316d8a7227064

The implementation combines a sliding window mask with a document mask. The masks are created once per input and re-used for all subsequent layers.
One possible cause is that the flex_attention function is not compiled in transformers.
I might be missing something; thanks in advance for your help.

Using attn_implementation=flex_attention
Sequence length : torch.Size([1, 702])
  Time taken: 0.7820 seconds

Using attn_implementation=sdpa
Sequence length : torch.Size([1, 702])
  Time taken: 0.0748 seconds

Using attn_implementation=eager
Sequence length : torch.Size([1, 702])
  Time taken: 0.0679 seconds
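
For reference, the pattern described above (block mask built once per input and re-used across layers, with flex_attention compiled) looks roughly like this. This is only a minimal sketch with illustrative names and shapes (doc_ids, the 702-token input), not the actual transformers integration:

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# compile once, re-use for every layer/call
flex_attention = torch.compile(flex_attention, dynamic=False)

WINDOW_SIZE = 64
# illustrative: one packed input of 702 tokens, all belonging to document 0
doc_ids = torch.zeros(702, dtype=torch.long, device="cuda")

def sliding_window_doc_mask(b, h, q_idx, kv_idx):
    same_doc = doc_ids[q_idx] == doc_ids[kv_idx]       # document mask
    in_window = (q_idx - kv_idx).abs() <= WINDOW_SIZE  # sliding window mask
    return same_doc & in_window

# created once per input, then passed to flex_attention in every layer
block_mask = create_block_mask(
    sliding_window_doc_mask, B=None, H=None, Q_LEN=702, KV_LEN=702, device="cuda"
)

q = k = v = torch.randn(1, 16, 702, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask)
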
@staghado
Author

When running FlexAttention vs. SDPA alone (with compile), I get:

Torch version: 2.6.0.dev20241112+cu121

=== Benchmark Results ===
+--------------+--------------+----------------------+----------------------+
|   Batch Size |   Seq Length |   FLEX Avg Time (ms) |   SDPA Avg Time (ms) |
+==============+==============+======================+======================+
|            1 |          128 |      19157.4         |      18323.5         |
+--------------+--------------+----------------------+----------------------+
|            1 |          256 |      29308.9         |      26515.9         |
+--------------+--------------+----------------------+----------------------+
|            1 |          512 |      42290.9         |      43449.5         |
+--------------+--------------+----------------------+----------------------+
|            1 |         1024 |      47303.1         |      85003.2         |
+--------------+--------------+----------------------+----------------------+
|            1 |         2048 |      89719.6         |     221348           |
+--------------+--------------+----------------------+----------------------+
|            1 |         4096 |     170735           |     842239           |
+--------------+--------------+----------------------+----------------------+
|            1 |         8192 |     331801           |          3.34551e+06 |
+--------------+--------------+----------------------+----------------------+
|            2 |         8192 |     645975           |          6.44274e+06 |
+--------------+--------------+----------------------+----------------------+
|            4 |         8192 |          1.26423e+06 |          1.24122e+07 |
+--------------+--------------+----------------------+----------------------+
|            4 |         8192 |          1.26883e+06 |          1.24035e+07 |
+--------------+--------------+----------------------+----------------------+
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
from tabulate import tabulate

from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    create_mask,
)

torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)
print(f"Torch version: {torch.__version__}")

from torch._inductor.utils import do_bench_using_profiling
from typing import Callable
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

benchmark_fn = benchmark_cuda_function_in_microseconds

WINDOW_SIZE = 64

def generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE):
    def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx):
        # only allow attention within the same sequence
        same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx]

        # get position within the sequence
        q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]]
        kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]]

        # sliding window within each sequence
        in_window = (q_pos - kv_pos).abs() <= WINDOW_SIZE

        return same_seq & in_window
    return sliding_window_seq_mask_mod

def SWA_mask(b, h, q_idx, kv_idx):
    # sliding window within each sequence
    in_window = (q_idx - kv_idx).abs() <= WINDOW_SIZE
    return in_window

# Benchmarking function
def run_benchmark(batch_sizes, sequence_lengths, num_heads=16, hidden_dim=64, n_runs=3):
    results = []

    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            q = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")
            k = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")
            v = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")

            
            # per-document lengths for this batch; use a new name so the outer
            # `sequence_lengths` list driving the loop is not overwritten
            doc_lengths = [seq_len] * batch_size
            sequence_ids = torch.cat([torch.full((length,), i, dtype=torch.long) for i, length in enumerate(doc_lengths)]).to("cuda")
            _, counts = torch.unique_consecutive(sequence_ids, return_counts=True)
            cu_seqlens = torch.cat([torch.tensor([0], device=sequence_ids.device), counts.cumsum(0)])

            block_mask = create_block_mask(
                generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE),
                B=None,
                H=None,
                Q_LEN=int(cu_seqlens[-1]),
                KV_LEN=int(cu_seqlens[-1]),
                device="cuda",
            )
            mask = create_mask(SWA_mask, None, None, seq_len, seq_len, device="cuda")

            # Benchmark flex_attention
            flex_times = []
            for _ in range(n_runs):
                flex_time = benchmark_fn(
                    flex_attention,
                    q.reshape(1, num_heads, -1, hidden_dim),
                    k.reshape(1, num_heads, -1, hidden_dim),
                    v.reshape(1, num_heads, -1, hidden_dim),
                    score_mod=None,
                    block_mask=block_mask,
                )
                flex_times.append(flex_time)
            flex_avg_time = sum(flex_times) / n_runs / 1000  # benchmark_fn returns µs; convert to ms

            # Benchmark scaled_dot_product_attention with mask
            sdpa_times = []
            with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                for _ in range(n_runs):
                    sdpa_time = benchmark_fn(
                        scaled_dot_product_attention,
                        q,
                        k,
                        v,
                        attn_mask=mask,
                    )
                    sdpa_times.append(sdpa_time)
                sdpa_avg_time = sum(sdpa_times) / n_runs / 1000  # benchmark_fn returns µs; convert to ms

            results.append(
                {
                    "Batch Size": batch_size,
                    "Seq Length": seq_len,
                    "FLEX Avg Time (ms)": f"{flex_avg_time:.2f}",
                    "SDPA Avg Time (ms)": f"{sdpa_avg_time:.2f}",
                }
            )

    return results


if __name__ == "__main__":
    batch_sizes = [
        1,
        2,
        4,
    ]
    sequence_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
    n_runs = 5

    results = run_benchmark(batch_sizes, sequence_lengths, n_runs=n_runs)

    # Generate table
    print("\n=== Benchmark Results ===")
    print(tabulate(results, headers="keys", tablefmt="grid"))

So my question is: how can I cleanly integrate

torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)

into transformers?
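
One workaround I can think of is to compile the whole model, so the flex_attention call inside it is captured by torch.compile. This is only a sketch (the checkpoint name is a placeholder), not the official integration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch._dynamo.config.cache_size_limit = 1000

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flex_attention",
    torch_dtype=torch.bfloat16,
).to("cuda")
# compiling the model also compiles the flex_attention call inside the graph
model = torch.compile(model, dynamic=False)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("hello world", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model(**inputs)

But ideally the compilation would live inside the transformers attention integration so users get it without extra setup.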

@drisspg
Contributor

drisspg commented Dec 27, 2024

Without a doubt it will be slower than eager when it is not compiled. Let me ping some HF folks to see if we can raise a warning / ensure it is easy to compile.
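
To illustrate the point, here is a self-contained comparison of uncompiled vs. compiled flex_attention on the same inputs. This is just a sketch with a simple causal mask; exact numbers depend on the GPU and torch version:

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

q = k = v = torch.randn(1, 16, 2048, 64, device="cuda", dtype=torch.bfloat16)
block_mask = create_block_mask(causal, None, None, 2048, 2048, device="cuda")
compiled_flex = torch.compile(flex_attention, dynamic=False)

def time_ms(fn, iters=10):
    fn(q, k, v, block_mask=block_mask)  # warmup (and compile, for the compiled variant)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(q, k, v, block_mask=block_mask)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

print(f"uncompiled flex_attention: {time_ms(flex_attention):.2f} ms/call")
print(f"compiled flex_attention:   {time_ms(compiled_flex):.2f} ms/call")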
