30%+ Speedup for AMD RDNA3/ROCm using Flash Attention w/ SDP Fallback #7172
Replies: 4 comments 9 replies
- cc @sayakpaul
- Apparently it's possible for card names to not include "AMD", so I added another check for "Radeon" in the hijack and made it prettier.
- Hi there. I'm very impressed with the performance of my 7800 XT with quickdif and the memory optimization you've done. Is there a way to bring this optimization, I mean the flash attention part, to the auto1111 webui? Regardless, I appreciate your contribution to the AMD community! I created this account to share this comment :)
- Hi. As you said, official FA2 support on ROCm has been released, I guess via PyTorch 2.3. What is its significance for consumer cards like my 7800 XT? It broke my existing FA build when I installed PyTorch 2.3.
Yes, now you too can have memory efficient attention on AMD with some (many) caveats.
## Numbers
*[Chart] Throughput for the diffusers default (SDP), my SubQuad port, and the presented Flash Attention + SDP fallback method.*
All numbers were measured on a 7900 XTX with PyTorch nightly + ROCm 6.0. Additionally, I have my card power-limited to 300 W, so your numbers may be higher.
## Okay, how?
First you need to install Flash Attention, which requires the ROCm SDK to be installed; ideally your PyTorch ROCm version should match it. If your distro is on ROCm 6.0, use the nightly Torch to match. With that setup I've never had install issues on either ROCm 5.7 or 6.0.
To install Flash Attention, activate your virtual environment (if you use one, which you should) and execute the install command.
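Something along these lines should do it; note that the repo path and branch name here are my assumption about AMD's Navi-support branch rather than a verbatim command, so check the fork for the branch that matches your card:

```sh
# Builds against the torch in your active venv; the branch name is an
# assumption, check AMD's flash-attention fork for the current Navi branch.
pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support
```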
This installs AMD's Flash Attention 2 fork with Navi support. There's a very real chance it only works on 7000-series GPUs, as older cards lack WMMA instructions and I'm not sure this build has any fallback for them.
Then, to actually use Flash Attention in Diffusers, you need to implement it in an attention processor and provide a fallback for unsupported head dimensions, which I've already done here.
To use it, simply place the `flash_attn_rocm` file in your tree and import `FlashAttnProcessor`, as in the sketch below. With the processor set, inference should be much faster, use less memory (usually), and more.
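A minimal usage sketch, assuming the file is importable as a module named `flash_attn_rocm` and that `FlashAttnProcessor` takes no constructor arguments (both assumptions on my part); the SDXL checkpoint is only an example:

```py
import torch
from diffusers import StableDiffusionXLPipeline

# FlashAttnProcessor comes from the flash_attn_rocm file dropped into your tree.
from flash_attn_rocm import FlashAttnProcessor

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # any SD/SDXL checkpoint works
    torch_dtype=torch.float16,  # Flash Attention only runs in fp16/bf16
).to("cuda")  # ROCm still presents the GPU as the "cuda" device

# Route every UNet attention block through Flash Attention; head dims the
# Navi build can't handle drop back to SDP inside the processor.
pipe.unet.set_attn_processor(FlashAttnProcessor())

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```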
## You mentioned caveats?
Oh, there are plenty. Read about them here, but I'll summarize:
Firstly, the reason I keep mentioning "SDP fallback" is that the Navi branch currently does not support head dimensions > 128. I chose to fall back to good ole SDP for those, since the two functions are basically the same minus a few transposes.
The 128 head-dim limit results in memory spikes when it falls back to SDP, particularly in the VAE. This means that for large renders you'll probably have to use VAE tiling (snippet after these caveats; SubQuad might work too, but it'll be slow).
Second, this is forward pass only. No training. Like at all.
Finally, there's no masking support in the function. So far it seems to run ok, but this limitation might adversely affect some workflows.
Also, the AMD fork is like a billion versions behind the Dao-AILab master, so newer functions aren't available either. 2.0.4 is all we get. On top of this, it appears to be very unoptimized. It barely works but that's a lot better than the last year of nothing.
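On the VAE memory spikes from the first caveat: tiling is a one-liner on most diffusers pipelines, assuming a `pipe` object like the one in the earlier sketch:

```py
# Decode the latents in tiles so the VAE (whose attention always takes the
# SDP fallback due to its >128 head dim) doesn't spike VRAM on large renders.
pipe.enable_vae_tiling()
```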
## What about other models?
A lot of models use SDPA and don't expose an easy way to set attention themselves. What I'd recommend instead is to simply monkey-patch torch's SDPA function with your own wrapper that hijacks it into Flash Attention where supported. Example:
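A minimal sketch of the idea, assuming the flash_attn 2.0.x API; the exact gating conditions (no mask, no dropout, head dim <= 128, fp16/bf16) and the AMD/Radeon device-name check follow what's described in this thread rather than a verbatim copy of the original snippet:

```py
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

# Keep a handle to the stock implementation for the fallback path.
_sdpa = F.scaled_dot_product_attention

def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None):
    # Route to Flash Attention only when the Navi build can handle it:
    # 4D fp16/bf16 tensors, no mask, no dropout, head dim <= 128.
    if (
        attn_mask is None
        and dropout_p == 0.0
        and query.ndim == 4
        and query.size(-1) <= 128
        and query.dtype in (torch.float16, torch.bfloat16)
    ):
        # SDPA uses (batch, heads, seq, head_dim); flash_attn_func wants
        # (batch, seq, heads, head_dim), hence the transposes.
        out = flash_attn_func(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
            softmax_scale=scale,
            causal=is_causal,
        )
        return out.transpose(1, 2)
    # Everything else (VAE head dims, masked attention, fp32) stays on SDP.
    return _sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,
                 is_causal=is_causal, scale=scale)

# Only install the hijack on AMD hardware; some cards report as "Radeon"
# rather than "AMD" in the device name.
if any(s in torch.cuda.get_device_name() for s in ("AMD", "Radeon")):
    F.scaled_dot_product_attention = sdpa_hijack
```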
Enjoy that Stable Cascade speedup.