[HN Gopher] FlexAttention: The Flexibility of PyTorch with the P...
       ___________________________________________________________________
        
       FlexAttention: The Flexibility of PyTorch with the Performance of
       FlashAttention
        
       Author : limoce
       Score  : 193 points
       Date   : 2024-08-08 07:24 UTC (15 hours ago)
        
 (HTM) web link (pytorch.org)
 (TXT) w3m dump (pytorch.org)
        
       | andy12_ wrote:
       | This is so cool. I want to try to implement something with this
       | right now.
        
       | gchamonlive wrote:
        | I've always been curious to put something together with
        | PyTorch, but it always seemed like either a steep learning
        | curve or there wasn't a big motivator (a project, a problem to
        | solve, something in my daily routine to optimize).
        | 
        | Does anybody have a good starting point for learning with
        | hands-on projects, one that could also accommodate
        | FlexAttention?
        
         | ryneandal wrote:
         | IMO the PyTorch getting started tutorials are really good
         | (https://pytorch.org/tutorials/beginner/basics/intro.html).
         | 
         | A classifier for handwritten digits in the MNIST dataset is
         | generally considered the "Hello World" of neural networks. I
         | went over it in a course, but there are countless tutorials to
          | be found online, e.g.
         | https://www.digitalocean.com/community/tutorials/introductio...
         | 
         | Once you begin to understand how to handle data and how to
         | define layers, you can start playing around with whatever your
         | heart desires. The rabbit hole is vast and endless :)
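
          A minimal sketch of that "Hello World", for flavor (the
          architecture and hyperparameters below are arbitrary; real
          batches would come from torchvision.datasets.MNIST):

              import torch
              from torch import nn

              # Tiny fully-connected classifier for 28x28 digit images.
              model = nn.Sequential(
                  nn.Flatten(),
                  nn.Linear(28 * 28, 128),
                  nn.ReLU(),
                  nn.Linear(128, 10),
              )
              loss_fn = nn.CrossEntropyLoss()
              optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

              # One training step on a dummy batch of 64 "images".
              x = torch.randn(64, 1, 28, 28)
              y = torch.randint(0, 10, (64,))
              optimizer.zero_grad()
              loss = loss_fn(model(x), y)
              loss.backward()
              optimizer.step()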
        
         | jisaacso wrote:
         | Agreed that PyTorch tutorials are a great place to start.
         | Specific to flexattention, the blog references the accompanying
         | attention gym, which has a series of examples of how to use
         | flex: https://github.com/pytorch-labs/attention-gym/
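
          For a flavor of the API, a minimal sketch along the lines of
          the blog post and the attention-gym examples (shapes here are
          illustrative; the relative-bias score_mod is one of the blog's
          own examples, and compiling the call is what fuses it into a
          single kernel):

              import torch
              from torch.nn.attention.flex_attention import flex_attention

              # score_mod sees each raw attention score plus its batch,
              # head, query and key indices, and returns a modified score.
              def relative_bias(score, b, h, q_idx, kv_idx):
                  return score + (q_idx - kv_idx)

              q = torch.randn(2, 8, 1024, 64, device="cuda",
                              dtype=torch.float16)
              k, v = torch.randn_like(q), torch.randn_like(q)

              # torch.compile lowers this into a fused attention kernel.
              compiled = torch.compile(flex_attention)
              out = compiled(q, k, v, score_mod=relative_bias)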
        
         | sva_ wrote:
          | Check out Kaggle for its challenges.
        
       | brrrrrm wrote:
        | For most LLM workloads today (short text chats), hundreds or a
        | couple thousand tokens suffice, and attention doesn't dominate
        | the compute (< 30%). But as the modalities inevitably grow, work
        | on attention approximation/compression is going to be paramount.
        | 
        | Nice to see PyTorch already elegantly supporting this next step
        | in research.
        
       | visarga wrote:
        | It's interesting that optimizing a computation that can be
        | described in a single line of math takes so much work. It took
        | forever even to discover FlashAttention, and in the years since
        | transformers were invented, thousands of papers have worked on
        | making attention faster.
        | 
        | Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
       | 
       | FlexAttention seems to have found the right abstraction for the
       | task.
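
        Spelled out naively in PyTorch, that one line is just this; it
        materializes the full (seq x seq) score matrix, which is exactly
        the memory traffic FlashAttention-style kernels avoid:

            import math
            import torch

            def naive_attention(Q, K, V):
                # Q, K, V: (batch, heads, seq_len, d_k)
                d_k = Q.shape[-1]
                scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
                return torch.softmax(scores, dim=-1) @ V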
        
         | d3m0t3p wrote:
          | Yeah, because the math has stripped the whole thing down to
          | "I have data, I apply an operation to it," while in reality
          | we deal with multi-head attention / grouped-query attention
          | and positional encoding.
          | 
          | That's all without taking into account the broadcasting done
          | over the batch dimension.
        
           | chillee wrote:
           | I would agree with this. For example, how would you represent
           | causal attention in the standard equation?
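
            In FlexAttention itself, causality is written as a mask_mod
            rather than folded into the formula; a sketch following the
            blog post (sequence length and shapes are arbitrary):

                import torch
                from torch.nn.attention.flex_attention import (
                    create_block_mask, flex_attention)

                # Keep only keys at or before the query position.
                def causal(b, h, q_idx, kv_idx):
                    return q_idx >= kv_idx

                # B=None, H=None broadcast over batch and heads.
                block_mask = create_block_mask(causal, B=None, H=None,
                                               Q_LEN=1024, KV_LEN=1024)
                q = torch.randn(2, 8, 1024, 64, device="cuda",
                                dtype=torch.float16)
                k, v = torch.randn_like(q), torch.randn_like(q)
                out = flex_attention(q, k, v, block_mask=block_mask)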
        
         | brrrrrm wrote:
          | This is true of even just matrix multiplication (A*B), of
          | which attention has two.
        
       | chillee wrote:
       | Hi, one of the authors of this blog post (Horace He), along with
       | Driss Guessous, Yanbo Liang, and Joy Dong.
       | 
       | We're quite happy with this abstraction - happy to answer any
       | questions about it!
        
         | zaptrem wrote:
         | For those of us using the 2D NATTEN kernel from their library
         | along with torch.compile, is this faster? Especially given all
         | their tricks (e.g., the non-deterministic KV-parallelism)
        
           | chillee wrote:
           | In my (very amateurish) testing, I think the performance
           | seemed pretty comparable (for non-dilated natten). I need to
           | do some proper benchmarking though!
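
            For anyone curious what "NATTEN via flex" means concretely,
            a rough sketch of a 2D sliding-window mask_mod; note this
            simple Chebyshev-distance window ignores NATTEN's clamping
            of the neighbourhood at image borders, so it is only an
            approximation:

                import torch
                from torch.nn.attention.flex_attention import (
                    create_block_mask, flex_attention)

                GRID, WIN = 32, 7      # 32x32 token grid, 7x7 window
                S = GRID * GRID        # flattened sequence length

                def window_2d(b, h, q_idx, kv_idx):
                    q_r, q_c = q_idx // GRID, q_idx % GRID
                    k_r, k_c = kv_idx // GRID, kv_idx % GRID
                    return (((q_r - k_r).abs() <= WIN // 2)
                            & ((q_c - k_c).abs() <= WIN // 2))

                mask = create_block_mask(window_2d, B=None, H=None,
                                         Q_LEN=S, KV_LEN=S)
                q = torch.randn(1, 8, S, 64, device="cuda",
                                dtype=torch.float16)
                k, v = torch.randn_like(q), torch.randn_like(q)
                out = flex_attention(q, k, v, block_mask=mask)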
        
       | barrenko wrote:
       | Can someone do a short summary or TL;DR for this?
        
         | chillee wrote:
         | https://x.com/chhillee/status/1821253769147118004?s=46
         | 
         | Perhaps this tweet thread would be better.
        
           | sva_ wrote:
           | https://nitter.poast.org/chhillee/status/1821253769147118004
        
             | barrenko wrote:
              | Thanks, just weaned myself off Twitter / X.
        
       | alecco wrote:
       | > FlexAttention achieves 90% of FlashAttention2's performance in
       | the forward pass and 85% in the backward pass.
       | 
       | It's very good. But note FlashAttention-3 is 1.5x - 2x faster
       | than FlashAttention-2.
        
         | chillee wrote:
         | These benchmarks are on Ampere, where FA3 has no performance
         | benefits over FA2.
         | 
         | On Hopper, FlexAttention is currently about 80% of
         | FlashAttention3's performance (about 500 TFLOPs peak)
        
       ___________________________________________________________________
       (page generated 2024-08-08 23:00 UTC)