[HN Gopher] FlashAttention-3: Fast and Accurate Attention with A...
       ___________________________________________________________________
        
       FlashAttention-3: Fast and Accurate Attention with Asynchrony and
       Low-Precision
        
       Author : jhshah
       Score  : 160 points
       Date   : 2024-07-11 17:06 UTC (5 hours ago)
        
 (HTM) web link (www.together.ai)
 (TXT) w3m dump (www.together.ai)
        
       | edude03 wrote:
        | How much is the flash attention algorithm tied to the hardware?
        | For example, in this announcement they mention taking advantage
        | of the async capabilities of H100 GPUs, which I assume means you
        | won't get those speedups on non-H-series cards. Second, the
        | actual flash attention library requires CUDA, although the
        | algorithm has apparently[0] been ported to Metal. I would
        | imagine that if the algorithm were literally just a pure
        | function it could be implemented for any GPU/ML framework?
       | 
       | [0]: https://github.com/philipturner/metal-flash-attention
        
         | kristjansson wrote:
          | FlashAttention's algorithmic improvement is mostly just
          | splitting/combining the softmax part of attention, and is
          | itself not totally novel. The overwhelming contribution is
          | implementing that, and all its fiddly pieces, efficiently on
          | Nvidia hardware.
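          | 
          | The splitting trick itself is small - softmax is an
          | "algebraic aggregate", so per-chunk (max, sum-of-exp)
          | statistics can be merged. A minimal numpy sketch of that
          | identity (not FA's kernel, just the algebra it relies on):
          | 
          |   import numpy as np
          | 
          |   # softmax over [a, b] from per-chunk (max, sum-of-exp)
          |   a, b = np.random.randn(8), np.random.randn(8)
          |   ma, la = a.max(), np.exp(a - a.max()).sum()
          |   mb, lb = b.max(), np.exp(b - b.max()).sum()
          |   m = max(ma, mb)                            # combined max
          |   l = la*np.exp(ma - m) + lb*np.exp(mb - m)  # combined sum
          |   full = np.concatenate([a, b])
          |   ref = np.exp(full - full.max())
          |   ref /= ref.sum()
          |   assert np.allclose(np.exp(full - m) / l, ref)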
        
           | namibj wrote:
            | To clarify further, flash attention explicitly targets a
            | compute engine with separate MMA and "scalar" vector
            | execution units that allow post-processing the MMA outputs
            | without involving memory bandwidth (though arithmetic
            | intensity, especially the relative intensity between the
            | MMA and the "scalar" instructions, is a concern). It also
            | assumes a substantial amount of manually-managed L1D$ to
            | use as a sub-matrix accumulator, and a linear-in-context-
            | length amount of "VRAM" that requires sensible arithmetic
            | intensity to avoid becoming a bandwidth bottleneck (IIRC
            | in the hundreds when counting the scalar multiplies hiding
            | in the MMA instructions).
           | 
           | This v3 with async might for once be so tied to Hopper that
           | it's not trivially portable to another platform that has the
           | mentioned hardware blocks (AFAIK every AMD GCN card that can
           | do compute shaders would qualify, though they do lack a
           | specialized MMA unit).
        
           | refulgentis wrote:
           | Clarifying:
           | 
           | Given the question: "How much is the flash attention
           | algorithm tied to the hardware?"
           | 
           | The answer is 0.
           | 
            | E.g. you can find generic flash attention recently added
            | to llama.cpp and ONNX (MS needed it for Phi-3, which is
            | needed for Recall).
            | 
            | As an aside, on novelty I have no direct knowledge, but
            | IMHO asking that question would devolve the way novelty
            | arguments do in any field: there's always someone else who
            | can claim they did 80% of $X via $X-1, therefore $X is by
            | and large not novel. Ad infinitum.
        
             | kristjansson wrote:
             | I think the right analogy for FA is high-quality cache-
             | aware BLAS kernel implementations. The algorithm(s) is
             | (are) clever and (as you note) completely independent of
              | hardware. However, a hardware-naive implementation is
              | approximately worthless. Most of the value of MKL, or
              | Accelerate, or FA is in the careful matching of the
              | parameters and implementation of the algorithm to the
              | capabilities of the hardware it's going to run on.
             | 
             | I definitely don't mean to take away from Tri/FA by
              | mentioning novelty - I'm just repeating from the paper,
              | which refers back to algebraic aggregates[0] in its
              | discussion of their tiled softmax.
             | 
             | [0]: https://web.stanford.edu/class/cs345d-01/rl/olap.pdf
        
         | f_devd wrote:
         | > How much is the flash attention algorithm tied to the
         | hardware?
         | 
         | The original FA, almost none.
         | 
          | For the latest versions it depends on your abstraction:
          | ThunderKittens[0] provides about the same speedup over FA2
          | (1.3x-2x) as the article, while being relatively universal
          | across GPUs. For any new hardware there may be hardware-
          | specific features that let it edge out more performance;
          | usually vendors will adopt any new feature that seems to
          | beat them, but you do get fragmented APIs/libraries (which
          | is already true for CUDA).
         | 
         | [0]: https://hazyresearch.stanford.edu/blog/2024-05-12-tk
        
           | kristjansson wrote:
           | I mean they're building an API to abstract away some of the
           | SKU-to-SKU differences, but the broader point cuts the other
           | way, I think:
           | 
           | > In fact, more broadly we believe we should really reorient
           | our ideas of AI around what maps well onto the hardware. How
            | big should a recurrent state be? As big as can fit onto an
            | SM. How dense should the compute be? No less so than what
            | the hardware demands. An important future direction of
            | this work for us is to use our learnings about the
            | hardware to help us design the AI to match.
           | 
           | The value is in adapting the implementation (either manually
           | at write-time or programmatically at run-time) to the
           | specifics of the hardware.
           | 
           | Also, great line:
           | 
           | > And we ask: if your matrix multiply is smaller than 16x16,
           | are you sure what you're doing is AI?
        
         | vhiremath4 wrote:
         | There are a bunch of good answers, but I wanted to succinctly
         | say "practically, quite a bit". Here's a good little rabbit-
         | hole example:
         | 
         | > https://github.com/karpathy/nanoGPT/blob/master/model.py#L45
         | 
          | Karpathy's nanoGPT calls flash attention by checking whether
          | torch.nn.functional.scaled_dot_product_attention exists.
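          | 
          | Roughly (a sketch of that nanoGPT-style dispatch, assuming a
          | recent PyTorch; the manual branch is just illustrative):
          | 
          |   import math, torch
          |   import torch.nn.functional as F
          | 
          |   # fused kernel (can dispatch to FlashAttention) if present
          |   flash = hasattr(F, 'scaled_dot_product_attention')
          | 
          |   def attend(q, k, v):  # (batch, heads, seq, head_dim)
          |       if flash:
          |           return F.scaled_dot_product_attention(
          |               q, k, v, is_causal=True)
          |       # fallback: materializes the full seq x seq scores
          |       att = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
          |       n = q.size(-2)
          |       causal = torch.ones(n, n, dtype=torch.bool).tril()
          |       att = att.masked_fill(~causal, float('-inf'))
          |       return att.softmax(dim=-1) @ v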
         | 
         | >
         | https://pytorch.org/docs/stable/generated/torch.nn.functiona...
         | 
          | Looking at the docs, in reality, most of the time you want
          | this to call out to FA2, which optimizes the kernels on the
          | device to split ops on the softmax of the triangular matrix,
          | and to reduce unnecessary movement of blocks of floating
          | point numbers between the GPU's main memory (HBM) and its
          | on-chip SRAM.
         | 
         | > https://arxiv.org/pdf/2307.08691
         | 
          | The FA2 paper frames itself almost entirely in terms of the
          | hardware it's running on.
        
         | 3abiton wrote:
          | To add to the discussion, from a practical perspective, AMD
          | hardware totally sucks and has yet to get a proper flash-
          | attention-2 implementation. ROCm is slowly moving toward
          | usable, but it's not close to being even comparable with
          | CUDA.
        
         | slashdave wrote:
          | Conceptually, just a bit; practically (in terms of
          | implementation), a lot. The standard Python implementation
          | internally compiles a kernel for your specific hardware.
        
       | lxe wrote:
       | > FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).
       | 
       | How does FA3 fare for consumer GPUs such as 3090 and 4090?
        
         | apsec112 wrote:
          | It's Hopper-specific; the improvements are closely tied to
         | Hopper features like warp groups and TMA. For 4090s, you might
         | get a speedup by using the Triton implementation of FP8
         | attention: https://triton-lang.org/main/getting-
         | started/tutorials/06-fu...
        
           | moffkalast wrote:
           | The original flash attention (v1?) took like a year to get
           | added to llama.cpp and only provides single digit percent
           | VRAM savings for typical context lengths and practically no
           | speed boost. Still nice to have, but man was this thing
           | overhyped. I doubt v3 will do more than marginally better on
           | the RTX 5000 series.
        
             | apsec112 wrote:
             | On GPU, or on CPU/Metal? For the latter I'm not surprised,
             | but that's because they have a totally different
             | memory/cache hierarchy.
        
               | moffkalast wrote:
               | With CUDA offloading, I don't think it runs otherwise at
               | all.
        
       | Der_Einzige wrote:
        | This is one of the most important improvements in all of AI,
        | because it benefits most AI users by giving them more speed
        | and capacity on the same hardware, with little to no
        | tradeoffs.
        
         | snovv_crash wrote:
         | ...for all those users with H100s.
        
           | rfoo wrote:
           | ... which is currently the most cost-efficient and
           | environment-friendly way to do LLM inference [0].
           | 
            | [0] Small print time: before B100 ships; for actually
            | large language models; for prefill only; may cause cancer
            | in California.
        
       | andy_xor_andrew wrote:
       | hoping an expert can answer a few Qs I have :)
       | 
       | Is FlashAttention simply a drop-in replacement for the attention
       | operation in an LLM? Can it be used anywhere that an "attention"
       | operation is used? Or does a LLM need to be trained specially to
       | use FA?
       | 
       | How does FA relate to attention _strategies_ like GQA (grouped
       | query attention) or sliding-window attention? Are they orthogonal
       | concepts? Or you need a specific FA implementation for each
       | strategy?
       | 
       | Recently llama.cpp added flash attention support - does this just
       | mean they started consuming a flash attention-provided CUDA
       | kernel or something?
       | 
       | lastly, in this post, they compare FlashAttention to Triton. I
       | thought Triton was like an abstraction layer? Couldn't FA be
       | implemented in Triton? I just don't really get what it means to
       | say "FlashAttention vs. Triton".
        
         | zaptrem wrote:
         | > Is FlashAttention simply a drop-in replacement for the
         | attention operation in an LLM? Can it be used anywhere that an
         | "attention" operation is used? Or does a LLM need to be trained
         | specially to use FA?
         | 
          | Yes - it's a drop-in replacement (the output is
          | mathematically the same), so no special training is needed.
         | 
         | > How does FA relate to attention strategies like GQA (grouped
         | query attention) or sliding-window attention? Are they
         | orthogonal concepts? Or you need a specific FA implementation
         | for each strategy?
         | 
          | Flash Attention is a way of calculating the Softmax(QK^T)V
          | part of attention, whereas GQA is a way of calculating the
          | Q, K, and V matrices. Sliding window attention (less sure
          | about this, there are a bunch of windowed attention
          | techniques) changes the attention mask (the thing that
          | controls which queries can attend to which keys).
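          | 
          | They compose: GQA just gives K and V fewer heads, with each
          | K/V head shared by a group of query heads, and the
          | Softmax(QK^T)V part runs as usual. A toy sketch (real
          | implementations handle the grouping inside the kernel rather
          | than materializing repeated heads):
          | 
          |   import torch
          |   import torch.nn.functional as F
          | 
          |   b, n, d = 2, 128, 64
          |   n_q, n_kv = 8, 2      # GQA: 4 query heads per KV head
          |   q = torch.randn(b, n_q, n, d)
          |   k = torch.randn(b, n_kv, n, d)
          |   v = torch.randn(b, n_kv, n, d)
          | 
          |   # expand each KV head across its query-head group,
          |   # then run (flash) attention exactly as before
          |   g = n_q // n_kv
          |   k = k.repeat_interleave(g, dim=1)
          |   v = v.repeat_interleave(g, dim=1)
          |   out = F.scaled_dot_product_attention(
          |       q, k, v, is_causal=True)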
         | 
         | > Recently llama.cpp added flash attention support - does this
         | just mean they started consuming a flash attention-provided
         | CUDA kernel or something?
         | 
         | I don't use llama.cpp but that sounds about right.
         | 
         | > lastly, in this post, they compare FlashAttention to Triton.
         | I thought Triton was like an abstraction layer? Couldn't FA be
         | implemented in Triton? I just don't really get what it means to
         | say "FlashAttention vs. Triton".
         | 
         | They're talking about a previous Flash Attention implementation
         | written in Triton.
        
         | apsec112 wrote:
         | 1) Pretty much, it's mathematically equivalent. The only
         | software issues are things like managing dependency versions
         | and data formats in-memory, but Flash Attention 2 is already
         | built into HuggingFace and other popular libraries. Flash
         | Attention 3 probably will be soon, although it requires an H100
          | GPU to run.
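          | 
          | (For example, in transformers you can opt in at load time -
          | the model name below is just a placeholder, and it assumes
          | the flash-attn package is installed:)
          | 
          |   import torch
          |   from transformers import AutoModelForCausalLM
          | 
          |   model = AutoModelForCausalLM.from_pretrained(
          |       "some-org/some-model",        # placeholder
          |       torch_dtype=torch.bfloat16,
          |       attn_implementation="flash_attention_2",
          |       device_map="auto",
          |   )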
         | 
         | 2) Flash Attention 2 added support for GQA in past version
         | updates:
         | 
         | https://github.com/Dao-AILab/flash-attention
         | 
         | 3) They're comparing this implementation of Flash Attention
         | (which is written in raw CUDA C++) to the Triton implementation
         | of a similar algorithm (which is written in Triton):
         | https://triton-lang.org/main/getting-started/tutorials/06-fu...
        
       | localfirst wrote:
       | spoiler: $xxx,xxx hardware required to run
        
         | sva_ wrote:
         | $25k-$30k
        
         | aabhay wrote:
         | If you need to run it continuously for a year
        
       | WanderPanda wrote:
       | Compiler folks: Is there any chance compilers will be able to
       | find optimizations like FlashAttention on their own? Seems like
       | TVM and tinygrad are working in that direction but I find it hard
       | to believe that that would be feasible
        
         | rfoo wrote:
         | No. Think of it like a different algorithm. You just take the
         | shape of the hardware into consideration when designing the
         | algorithm instead of considering math only.
         | 
         | > Seems like TVM
         | 
            | Fair enough - though technically they are still about
            | different things, it's indeed very close. But
         | 
         | > and tinygrad
         | 
         | ?????? what gives you this impression?
        
           | dauertewigkeit wrote:
           | What's the distinction between what TVM does and
           | FlashAttention type optimizations?
        
             | rfoo wrote:
             | There is more than layout / tile schedule in FA. For
             | example, first, to be able to fuse all these together [0]
             | at all, you need to "decompose" the softmax to make it
             | combinable, which requires maintaining some extra
              | statistics. Not gonna repeat the math here as the
              | original FA paper is already very clear.
             | 
              | [0] so you can avoid materializing intermediate matrices
              | and still be able to compute in blocks.
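              | 
              | A toy numpy version of that decomposition (running
              | max m, denominator l, unnormalized output o; rescale
              | the old stats whenever a new block raises the max -
              | not the real kernel, just the math being fused):
              | 
              |   import numpy as np
              | 
              |   def attn_blockwise(q, k, v, blk=64):
              |       scale = 1.0 / np.sqrt(q.shape[-1])
              |       m = np.full(q.shape[0], -np.inf)
              |       l = np.zeros(q.shape[0])
              |       o = np.zeros((q.shape[0], v.shape[1]))
              |       for j in range(0, k.shape[0], blk):
              |           s = q @ k[j:j+blk].T * scale
              |           m_new = np.maximum(m, s.max(axis=1))
              |           p = np.exp(s - m_new[:, None])
              |           r = np.exp(m - m_new)   # rescale factor
              |           l = l * r + p.sum(axis=1)
              |           o = o * r[:, None] + p @ v[j:j+blk]
              |           m = m_new
              |       return o / l[:, None]
              | 
              |   q = np.random.randn(128, 32)
              |   k = np.random.randn(256, 32)
              |   v = np.random.randn(256, 32)
              |   s = q @ k.T / np.sqrt(32)
              |   ref = np.exp(s - s.max(1, keepdims=True))
              |   ref = ref / ref.sum(1, keepdims=True)
              |   assert np.allclose(attn_blockwise(q, k, v), ref @ v)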
        
         | namibj wrote:
         | In theory, yes, it's "just" some algebraic properties of the
         | math used that allow for substantial reordering, and then you'd
         | add fairly regular polyhedral loop tiling. Just expensive to
         | do, so you'll have to cache the effort.
         | 
         | The area of e-graph optimizers seems well-suited to this, btw.
          | It's not really deployed outside of some niche tooling
          | though, as it's a big paradigm shift in optimizer pass
          | handling (e.g., it doesn't work well with classic call
          | graphs, so control flow needs to be massively revamped to
          | deploy e-graphs outside/across basic blocks and for loops
          | (break and return not supported!)).
        
         | Lerc wrote:
         | This strikes me as an extremely difficult but not intractable
         | problem.
         | 
         | I'm not sure what the state of the art in compiler optimisation
         | is with regard to data positioning and targeting maximum
          | processor usage.
         | 
         | There was a video on optimisation a while back that showed
         | small optimisations caused increases in speed that were
         | insignificant when compared to the speed variance induced by
         | the memory layout that the optimisation (or even a random
         | change) caused.
         | 
          | While that talk was more focused on getting a signal past
          | the noise, that noise itself is an artifact of compilers
          | being not particularly good at handling a much simpler form
          | of the problem you describe.
         | 
         | CPU and memory architectures are complex when caches and access
         | patterns impact upon speed.
         | 
         | When you add in GPU architectures to the mix I think you might
         | be in fairly uncharted territory.
         | 
         | Maybe one day.
         | 
          | Of course, since we are in the field of AI, there is also
          | the question of whether a sufficiently smart AI could do
          | this. It depends on the value of "sufficient".
         | 
         | I would like to think that an extremely high level test for an
         | AI model could be to give it something like micrograd and tell
         | it to produce something with the same interface that
         | outperforms torch.
         | 
         | We're not even in the ballpark of being able to do that yet,
         | but it will be interesting when and if that happens.
        
       | ex3ndr wrote:
        | I am wondering why flash attention is like 5x slower with
        | variable masking than without it. Lack of good masking support
        | almost zeroes out the optimizations.
        
       ___________________________________________________________________
       (page generated 2024-07-11 23:00 UTC)