https://github.com/tspeterkim/flash-attention-minimal

# flash-attention-minimal

A minimal re-implementation of Flash Attention with CUDA and PyTorch. The official implementation can be quite daunting for a CUDA beginner (like myself), so this repo tries to be small and educational.

* The entire forward pass is written in ~100 lines in flash.cu.
* The variable names follow the notation of the original paper.

## Usage

### Prerequisites

* PyTorch (with CUDA)
* Ninja, for loading the C++/CUDA extension

### Benchmark

Compare the wall-clock time between manual attention and minimal flash attention:

```
python bench.py
```

Sample output on a T4:

```
=== profiling manual attention ===
...
Self CPU time total: 52.389ms
Self CUDA time total: 52.545ms

=== profiling minimal flash attention ===
...
Self CPU time total: 11.452ms
Self CUDA time total: 3.908ms
```

Speed-up achieved!

### I don't have a GPU

Try out this online colab demo.
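For a sense of what the comparison in `bench.py` does, the sketch below loads `flash.cu` as a PyTorch C++/CUDA extension and profiles it against a naive PyTorch attention. It is an illustrative sketch, not a copy of the script: the extension name (`minimal_attn`), the exposed `forward` entry point, and the tensor shapes are assumptions.

```python
# Hedged sketch of a manual-vs-flash comparison (module/function names are
# assumptions; see bench.py for the actual script).
import math
import torch
from torch.utils.cpp_extension import load

# JIT-build and load the extension (this is where Ninja is used).
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'],
                    extra_cuda_cflags=['-O2'])

batch, heads, seq_len, head_dim = 16, 12, 64, 64
q = torch.randn(batch, heads, seq_len, head_dim, device='cuda')
k = torch.randn(batch, heads, seq_len, head_dim, device='cuda')
v = torch.randn(batch, heads, seq_len, head_dim, device='cuda')

def manual_attn(q, k, v):
    # Materializes the full (seq_len x seq_len) score matrix -- the N^2 memory
    # traffic that the flash kernel avoids.
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    return torch.softmax(scores, dim=-1) @ v

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    out_manual = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=5))

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    out_flash = minimal_attn.forward(q, k, v)  # assumed entry point
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=5))

print('outputs match:', torch.allclose(out_manual, out_flash, atol=1e-2))
```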
## Caveats

* No backward pass! Honestly, I found the backward pass a lot more complex than the forward pass, and the forward pass was enough to demonstrate the use of shared memory to avoid the large N^2 reads/writes (sketched at the end of this README).
* In the inner loop, each thread is assigned to a row of the output matrix. This differs from the original implementation.
* This thread-per-row simplification makes the matrix multiplications very slow, which is probably why, for longer sequences and larger block sizes, this implementation becomes slower than manual attention.
* Q, K, and V are in float32, unlike the original implementation, which uses float16.
* The block size is fixed at 32 at compile time.

## Todos

* [ ] Add backward pass
* [ ] Speed up matmuls
* [ ] Dynamically set block size
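To make the shared-memory tiling mentioned in the caveats concrete, here is an illustrative PyTorch sketch of the block-wise, online-softmax forward pass from the Flash Attention paper, which is the algorithm flash.cu implements. It is a readability aid under assumed simplifications (single head, float32), not the kernel itself; in particular, the kernel's thread-per-row inner loop is not modeled here.

```python
# Illustrative sketch of the tiled, online-softmax forward pass (single head,
# float32). This mirrors the paper's algorithm, not the kernel's exact
# loop/thread structure.
import math
import torch

def tiled_attention(q, k, v, block=32):
    """Compute softmax(q @ k.T / sqrt(d)) @ v without materializing the N x N scores."""
    n, d = q.shape
    scale = 1.0 / math.sqrt(d)
    out = torch.zeros_like(q)                      # running (unnormalized) output, O in the paper
    row_max = torch.full((n, 1), float('-inf'))    # running row max, m in the paper
    row_sum = torch.zeros(n, 1)                    # running softmax denominator, l in the paper

    for start in range(0, n, block):
        kj = k[start:start + block]                # one K tile (what the kernel stages in shared memory)
        vj = v[start:start + block]                # one V tile
        s = (q @ kj.T) * scale                     # (n, block) scores for this tile only

        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - new_max)                 # tile probabilities under the new max
        correction = torch.exp(row_max - new_max)  # rescale everything accumulated so far

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vj
        row_max = new_max

    return out / row_sum

# Sanity check against naive attention, which builds the full N x N matrix.
n, d = 128, 64
q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
naive = torch.softmax((q @ k.T) / math.sqrt(d), dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), naive, atol=1e-4)
```

The key point is that only a `block`-sized slice of K and V (plus the corresponding score tile) is live at any time, which is what the kernel keeps in shared memory instead of writing an N x N matrix to global memory.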