[HN Gopher] FlexAttention: The Flexibility of PyTorch with the P...
___________________________________________________________________
FlexAttention: The Flexibility of PyTorch with the Performance of
FlashAttention
Author : limoce
Score : 193 points
Date : 2024-08-08 07:24 UTC (15 hours ago)
(HTM) web link (pytorch.org)
(TXT) w3m dump (pytorch.org)
| andy12_ wrote:
| This is so cool. I want to try to implement something with this
| right now.
| gchamonlive wrote:
| I've always been curious about putting something together with
| PyTorch, but there always seemed to be either a steep learning
| curve or no big motivator (a project, a problem to solve,
| something in my daily routine to optimize).
|
| Does anybody have a good starting point for learning with
| hands-on projects that could also accommodate FlexAttention?
| ryneandal wrote:
| IMO the PyTorch getting started tutorials are really good
| (https://pytorch.org/tutorials/beginner/basics/intro.html).
|
| A classifier for handwritten digits in the MNIST dataset is
| generally considered the "Hello World" of neural networks. I
| went over it in a course, but there are countless tutorials to
| be found online, e.g.
| https://www.digitalocean.com/community/tutorials/introductio...
|
| Once you begin to understand how to handle data and how to
| define layers, you can start playing around with whatever your
| heart desires. The rabbit hole is vast and endless :)
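|
| A minimal sketch of that "Hello World", assuming plain PyTorch
| (not taken from the linked tutorial; the hyperparameters and the
| dummy batch are placeholders for a real DataLoader):
|
|     import torch
|     from torch import nn
|
|     # Flatten 28x28 MNIST images, then two linear layers with a
|     # ReLU in between, producing 10 class logits.
|     model = nn.Sequential(
|         nn.Flatten(),
|         nn.Linear(28 * 28, 128),
|         nn.ReLU(),
|         nn.Linear(128, 10),
|     )
|
|     loss_fn = nn.CrossEntropyLoss()
|     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
|
|     # One training step on a dummy batch.
|     images = torch.randn(64, 1, 28, 28)
|     labels = torch.randint(0, 10, (64,))
|     loss = loss_fn(model(images), labels)
|     optimizer.zero_grad()
|     loss.backward()
|     optimizer.step()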
| jisaacso wrote:
| Agreed that PyTorch tutorials are a great place to start.
| Specific to flexattention, the blog references the accompanying
| attention gym, which has a series of examples of how to use
| flex: https://github.com/pytorch-labs/attention-gym/
| sva_ wrote:
| Check out Kaggle for the challenges.
| brrrrrm wrote:
| For most LLM workloads today (short text chats), a few hundred
| to a couple thousand tokens suffice, and attention doesn't
| dominate the cost (< 30% of compute). But as the modalities
| inevitably grow, work on attention approximation/compression is
| going to be paramount.
|
| Nice to see PyTorch already elegantly supporting this next step
| in research.
| visarga wrote:
| It's interesting that optimizing a computation that can be
| described in a single line of math takes so much work. It took
| forever even to discover FlashAttention, and in the seven years
| since transformers were introduced, thousands of papers have
| worked on making it faster.
|
| Attention(Q,K,V) = Softmax(Q*K^T/sqrt(d_k))*V
|
| FlexAttention seems to have found the right abstraction for the
| task.
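|
| For reference, that one-liner written out naively in PyTorch (a
| sketch of the math only; it materializes the full L x L score
| matrix, which is exactly what FlashAttention-style kernels
| avoid doing):
|
|     import torch
|
|     def naive_attention(Q, K, V):
|         # Q, K, V: (batch, heads, seq_len, d_k)
|         d_k = Q.shape[-1]
|         scores = Q @ K.transpose(-2, -1) / d_k**0.5  # (B, H, L, L)
|         return torch.softmax(scores, dim=-1) @ V     # (B, H, L, d_k)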
| d3m0t3p wrote:
| Yeah, because the math strips the whole thing down to "I have
| data and I do an operation on it," while in reality we deal
| with multi-head attention / grouped-query attention and
| positional encoding.
|
| And that's all without taking into account the broadcasting
| done over the batch dimension.
| chillee wrote:
| I would agree with this. For example, how would you represent
| causal attention in the standard equation?
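|
| In FlexAttention's terms (roughly following the API shown in
| the blog post), causal attention is just a mask_mod passed to
| flex_attention; the shapes below are illustrative:
|
|     import torch
|     from torch.nn.attention.flex_attention import (
|         flex_attention, create_block_mask)
|
|     # A query position may only attend to keys at the same or
|     # earlier positions.
|     def causal_mask(b, h, q_idx, kv_idx):
|         return q_idx >= kv_idx
|
|     B, H, L, D = 2, 8, 1024, 64
|     q, k, v = (torch.randn(B, H, L, D, device="cuda")
|                for _ in range(3))
|
|     block_mask = create_block_mask(causal_mask, B=None, H=None,
|                                    Q_LEN=L, KV_LEN=L)
|     out = torch.compile(flex_attention)(q, k, v,
|                                          block_mask=block_mask)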
| brrrrrm wrote:
| This is true of even just matrix multiplication (A*B), of which
| attention has two.
| chillee wrote:
| Hi, I'm one of the authors of this blog post (Horace He), along
| with Driss Guessous, Yanbo Liang, and Joy Dong.
|
| We're quite happy with this abstraction - happy to answer any
| questions about it!
| zaptrem wrote:
| For those of us using the 2D NATTEN kernel from their library
| along with torch.compile, is this faster? Especially given all
| their tricks (e.g., the non-deterministic KV-parallelism)
| chillee wrote:
| In my (very amateurish) testing, I think the performance
| seemed pretty comparable (for non-dilated natten). I need to
| do some proper benchmarking though!
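|
| For the curious, a rough sketch of what a NATTEN-style 2D
| neighborhood mask can look like as a flex mask_mod (the names,
| grid size, and window here are made up; attention-gym has a
| proper version, and real NATTEN also shifts the window at the
| image borders, which this omits):
|
|     import torch
|     from torch.nn.attention.flex_attention import create_block_mask
|
|     GRID_W, WINDOW = 32, 7  # 32x32 grid flattened to 1024 tokens
|
|     def natten_like_mask(b, h, q_idx, kv_idx):
|         # Recover 2D coordinates from the flattened token index.
|         q_y, q_x = q_idx // GRID_W, q_idx % GRID_W
|         k_y, k_x = kv_idx // GRID_W, kv_idx % GRID_W
|         return (((q_y - k_y).abs() <= WINDOW // 2)
|                 & ((q_x - k_x).abs() <= WINDOW // 2))
|
|     L = GRID_W * GRID_W
|     block_mask = create_block_mask(natten_like_mask, B=None,
|                                    H=None, Q_LEN=L, KV_LEN=L)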
| barrenko wrote:
| Can someone do a short summary or TL;DR for this?
| chillee wrote:
| https://x.com/chhillee/status/1821253769147118004?s=46
|
| Perhaps this tweet thread would be better.
| sva_ wrote:
| https://nitter.poast.org/chhillee/status/1821253769147118004
| barrenko wrote:
| Thanks, just weaned myself off Twitter / X.
| alecco wrote:
| > FlexAttention achieves 90% of FlashAttention2's performance in
| the forward pass and 85% in the backward pass.
|
| It's very good. But note FlashAttention-3 is 1.5x - 2x faster
| than FlashAttention-2.
| chillee wrote:
| These benchmarks are on Ampere, where FA3 has no performance
| benefits over FA2.
|
| On Hopper, FlexAttention is currently about 80% of
| FlashAttention3's performance (about 500 TFLOPs peak)
___________________________________________________________________
(page generated 2024-08-08 23:00 UTC)