Flash Attention Explained: FA1, FA2, FA3
Standard attention materializes an N×N matrix in GPU HBM, which makes it a major memory bottleneck for Transformers at long sequence lengths. Flash Attention computes the same result with tiling and recomputation, cutting the extra memory from O(N²) to O(N) in sequence length and delivering wall-clock speedups of 2-8x over standard attention. Here is the mechanism, why the GPU memory hierarchy is the root cause, and how each version (FA1, FA2, FA3) improved on the last.
The problem: attention is memory-bound
Attention computes softmax(QKᵀ/√d)·V. The product QKᵀ is an N×N matrix of scores (where N is the sequence length), which the softmax normalizes row by row before it is multiplied by V. Standard implementations write that full N×N matrix to GPU High Bandwidth Memory (HBM), then read it back for the softmax. At sequence length 8,192 in FP32, that matrix is 8,192 × 8,192 × 4 bytes ≈ 268 MB per head, and most of the time is spent shuttling data between HBM and the GPU's fast SRAM, not doing math. This is what "memory-bound" means: the GPU's compute units starve waiting for memory transfers.
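The arithmetic above is easy to reproduce. A minimal sketch, assuming FP32 scores (4 bytes per element) as in the example; `attention_matrix_bytes` is a hypothetical helper name, not part of any library:

```python
def attention_matrix_bytes(seq_len: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to hold one N×N score matrix for a single attention head."""
    return seq_len * seq_len * bytes_per_element


# Footprint per head at a few sequence lengths (FP32 assumed).
for n in (2_048, 8_192, 32_768):
    megabytes = attention_matrix_bytes(n) / 1e6
    print(f"N={n:>6}: {megabytes:,.0f} MB per head")
```

Because the footprint grows with N², quadrupling the sequence length multiplies it by 16.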
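The GPU memory hierarchy is the root cause. An H100 has 80 GB of HBM at roughly 3 TB/s of bandwidth, but only about 228 KB of shared memory (SRAM) per streaming multiprocessor. That on-chip SRAM is roughly an order of magnitude faster: the original FlashAttention paper estimates about 19 TB/s of aggregate SRAM bandwidth on the older A100. The standard attention kernel dumps everything to slow HBM; the fix is to keep the work in fast SRAM. That is the entire insight behind Flash Attention.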
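To get a feel for how small the SRAM budget is, the sketch below checks whether one tile's working set fits in the 228 KB quoted above. It assumes FP16 elements and a simplified layout that holds one block each of Q, K, V, and the output plus one block of scores. Real kernels keep some of this in registers and reserve room for buffering, so treat the numbers as rough. The function names are hypothetical.

```python
SRAM_BYTES_PER_SM = 228 * 1024  # shared-memory figure quoted above for H100


def tile_working_set_bytes(block_q, block_kv, head_dim, bytes_per_element=2):
    """Approximate bytes for one tile: Q and O blocks, K and V blocks, scores."""
    q_and_o = 2 * block_q * head_dim
    k_and_v = 2 * block_kv * head_dim
    scores = block_q * block_kv
    return (q_and_o + k_and_v + scores) * bytes_per_element


def fits_in_sram(block_q, block_kv, head_dim):
    return tile_working_set_bytes(block_q, block_kv, head_dim) <= SRAM_BYTES_PER_SM


for block in (64, 128, 256):
    size_kb = tile_working_set_bytes(block, block, 128) / 1024
    print(f"{block}x{block} tile, d=128: {size_kb:.0f} KB, fits={fits_in_sram(block, block, 128)}")
```

Under these assumptions a 128×128 tile fits comfortably and a 256×256 tile does not, which is why block sizes are a tuning parameter tied to the hardware.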
FA1: tiling + recomputation
FlashAttention-1 (Dao et al., 2022) combined two established techniques. Tiling breaks the Q, K, V matrices into blocks small enough to fit in SRAM, computes partial attention scores block by block, and accumulates the softmax incrementally with running statistics (a per-row maximum and sum), so the full N×N matrix never materializes in HBM. Recomputation handles the backward pass: instead of storing the attention matrix for gradients (which would defeat the memory savings), it recomputes each block from Q, K, V and the saved per-row softmax statistics during backprop, trading redundant compute for the memory win.
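```python
import numpy as np

def standard_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])               # N×N scores written to HBM
    P = np.exp(S - S.max(axis=1, keepdims=True))    # read N×N, write N×N
    return (P / P.sum(axis=1, keepdims=True)) @ V   # read N×N

def tiled_attention(Q, K, V, block=64):
    # Flash-style loop over K/V blocks; the full N×N matrix is never formed.
    # (A real kernel also tiles Q and keeps each tile in SRAM.)
    N, d = Q.shape
    O, m, l = np.zeros((N, V.shape[1])), np.full(N, -np.inf), np.zeros(N)
    for j in range(0, K.shape[0], block):
        S = Q @ K[j:j + block].T / np.sqrt(d)       # partial scores, N×block
        m_new = np.maximum(m, S.max(axis=1))        # running row max
        P = np.exp(S - m_new[:, None])
        scale = np.exp(m - m_new)                   # rescales earlier blocks
        l = l * scale + P.sum(axis=1)               # running row sum
        O = O * scale[:, None] + P @ V[j:j + block] # accumulate output
        m = m_new
    return O / l[:, None]                           # same math as standard_attention
```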
The result: memory drops from O(N²) to O(N), and because the kernel spends far less time on HBM transfers, it is 2-4x faster wall-clock despite doing slightly more total FLOPs (from recomputation). Critically, the output is mathematically identical to standard attention — this is exact computation, not an approximation.
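The recomputation half can be illustrated for a single query row. This is a minimal pure-Python sketch, not the kernel's actual code: it assumes the forward pass keeps only one scalar per row (the log-sum-exp of that row's scores) and that the backward pass rebuilds probabilities one key block at a time from it. For brevity the forward helper looks at the whole row at once, whereas a real kernel would build the statistic block by block. All function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def row_logsumexp(q, keys):
    """Forward pass for one query row: keep only log(sum(exp(scores)))."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(s - m) for s in scores))


def recompute_block_probs(q, key_block, lse):
    """Backward pass: rebuild softmax probabilities for one block of keys."""
    scale = 1.0 / math.sqrt(len(q))
    return [math.exp(dot(q, k) * scale - lse) for k in key_block]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
lse = row_logsumexp(q, keys)  # the only per-row value saved for backward

probs = []
for start in range(0, len(keys), 2):  # visit the keys two at a time
    probs += recompute_block_probs(q, keys[start:start + 2], lse)

print(sum(probs))  # the blocks together form one full softmax row
```

Storing one scalar per row costs O(N) memory, while storing the probabilities themselves would cost O(N²).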
[IMAGE: GPU memory hierarchy — HBM (80GB, slow) vs SRAM (228KB, fast); tiles moving between them]
FA2: better parallelism (2x over FA1)
FlashAttention-2 (Dao, 2023) kept the tiling and recomputation but improved GPU parallelism and work partitioning. FA1 underutilized the GPU for two reasons: it parallelized only across batch and heads, which leaves streaming multiprocessors idle when sequences are long and batches are small, and its warps exchanged intermediate results through shared memory. FA2 also parallelizes across the sequence length and repartitions the work between warps inside each thread block to cut synchronization and shared-memory traffic. The outcome is roughly 2x speedup over FA1 on the same hardware, with no change to the attention output.
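A rough count of independent thread blocks shows why parallelizing only over batch and heads can leave a GPU idle. The sketch assumes the scheme described in the FlashAttention-2 paper (FA1 launches one thread block per batch-and-head pair, FA2 additionally splits the query sequence into blocks), an H100 SXM with 132 SMs, and an illustrative query block size of 128. The function names are hypothetical.

```python
import math

NUM_SMS = 132  # assumption: H100 SXM


def blocks_batch_heads(batch, heads):
    """Thread blocks when work is split over batch and heads only."""
    return batch * heads


def blocks_with_seq_split(batch, heads, seq_len, block_q=128):
    """Thread blocks when the query sequence is also split into blocks."""
    return batch * heads * math.ceil(seq_len / block_q)


batch, heads, seq_len = 1, 32, 16_384  # long sequence, small batch
fa1_style = blocks_batch_heads(batch, heads)
fa2_style = blocks_with_seq_split(batch, heads, seq_len)
print(f"batch*heads only:    {fa1_style} blocks for {NUM_SMS} SMs")
print(f"plus sequence split: {fa2_style} blocks for {NUM_SMS} SMs")
```

With fewer thread blocks than SMs, part of the chip has nothing to run; splitting along the sequence supplies enough blocks to fill it.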
As the Stanford Hazy Research blog explains, FA2's gains came from reducing non-matmul FLOPs (which are far more expensive than matmul FLOPs on Tensor Core hardware), better distributing work across the GPU's thread hierarchy, and avoiding the occupancy problems of FA1. By 2024, FA2 had become the default attention kernel in virtually every major inference and training stack: PyTorch, vLLM, SGLang, Megatron, and the major Transformer libraries all ship it.
FA3: exploiting Hopper asynchrony (1.5-2x on H100)
FlashAttention-3 (Shah et al., 2024) targets the Hopper architecture (H100) specifically. The FA3 paper (arXiv:2407.08608) reports 1.5-2.0x speedup over FA2 on H100, reaching up to 740 TFLOPs/s in FP16 — 75% of the H100's peak. Three Hopper-specific techniques deliver the gains.
First, overlap of matmul and softmax: Hopper's Tensor Core matmuls (WGMMA) run asynchronously, so FA3 schedules the softmax of one warpgroup while another warpgroup's matmuls are in flight, hiding the non-matmul latency that FA2 could not. Second, the Tensor Memory Accelerator (TMA): a dedicated unit that asynchronously copies data from HBM to shared memory, freeing the compute warps from address calculation. Third, FP8 support: FA3 supports FP8 (e4m3 and e5m2), whose peak Tensor Core throughput on Hopper is twice that of FP16, and uses block quantization and incoherent processing to limit quantization error. The paper reports close to 1.2 PFLOPs/s in FP8.
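The utilization figure quoted above can be sanity-checked with back-of-the-envelope arithmetic. The sketch assumes the FLOP count used in the FlashAttention papers for non-causal forward attention (4 · N² · d per head, covering the two matmuls) and a dense FP16 Tensor Core peak of about 989 TFLOPs/s for an H100 SXM. Both the peak constant and the helper names are assumptions for illustration.

```python
H100_FP16_PEAK_TFLOPS = 989.0  # assumption: dense FP16 Tensor Core peak, H100 SXM


def attention_forward_flops(batch, heads, seq_len, head_dim):
    """Matmul FLOPs for non-causal attention: QK^T and PV, 2*N*N*d each."""
    return 4 * batch * heads * seq_len * seq_len * head_dim


def utilization(achieved_tflops, peak_tflops=H100_FP16_PEAK_TFLOPS):
    return achieved_tflops / peak_tflops


flops = attention_forward_flops(batch=1, heads=32, seq_len=16_384, head_dim=128)
seconds = flops / 740e12  # time at the reported FP16 throughput
print(f"{flops / 1e12:.2f} TFLOPs of work, {seconds * 1e3:.2f} ms at 740 TFLOPs/s")
print(f"utilization: {utilization(740):.0%}")
```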
| Version | Key innovation | Speedup | HBM memory |
|---|---|---|---|
| Standard attention | — | Baseline | O(N²) |
| FA1 | Tiling + recomputation | 2-4x over standard | O(N) |
| FA2 | Better GPU parallelism | ~2x over FA1 | O(N) |
| FA3 | Hopper async + TMA + FP8 | 1.5-2x over FA2 (H100, 740 TFLOPs/s) | O(N) |
FA3's gains are H100-specific. On Ampere (A100) and older GPUs, FA2 remains the best choice — FA3's async Tensor Core and TMA techniques require Hopper. Always match the kernel to the GPU generation.
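One way to encode that rule is a lookup on the CUDA compute capability. This is a hedged sketch, not an official support matrix: it assumes compute capability 9.x means Hopper and 8.x means Ampere or Ada, and it deliberately returns `None` for anything else so the caller checks the kernel's documentation. `pick_flash_attention` is a hypothetical name.

```python
def pick_flash_attention(compute_capability):
    """Map a CUDA compute capability (major, minor) to a suggested kernel."""
    major, _minor = compute_capability
    if major == 9:
        return "FA3"  # Hopper (e.g. H100): async Tensor Cores and TMA available
    if major == 8:
        return "FA2"  # Ampere / Ada (e.g. A100)
    return None       # anything else: consult the kernel's support matrix


for capability in [(9, 0), (8, 0), (7, 5)]:
    print(capability, "->", pick_flash_attention(capability))
```

In PyTorch, `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple that could be passed to a function like this.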
Why it is exact, not approximate
A persistent misconception is that Flash Attention trades accuracy for speed. It does not. The tiled computation accumulates the softmax incrementally with a running row maximum and running sum (the online form of the log-sum-exp trick), which is algebraically identical to materializing the full matrix. Outputs agree up to floating-point rounding: the additions happen in a different order, so they are not guaranteed to be bit-identical. The backward-pass recomputation recomputes the same values the forward pass used. The memory and speed gains come entirely from reordering the computation to fit the memory hierarchy, not from dropping or approximating any term. This is why major training and serving stacks enable Flash Attention by default: at FP16/BF16 precision there is no quality tradeoff. The FP8 path in FA3 is the one place where precision is reduced, and that comes from the number format, not from the algorithm.
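The equivalence can be checked on a single row of scores. Below is a minimal pure-Python sketch of the running max and sum update, compared against a one-shot softmax; the function names and the block size of 3 are illustrative choices, not taken from any library.

```python
import math


def softmax_reference(scores):
    """One-shot softmax that sees the whole row at once."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_streaming(scores, block=3):
    """Visit the row block by block, keeping only a running max and sum."""
    running_max, running_sum = -math.inf, 0.0
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        new_max = max(running_max, max(chunk))
        # Rescale the old sum to the new max before adding this block.
        running_sum = running_sum * math.exp(running_max - new_max)
        running_sum += sum(math.exp(s - new_max) for s in chunk)
        running_max = new_max
    return [math.exp(s - running_max) / running_sum for s in scores]


scores = [2.0, -1.0, 0.5, 7.0, 3.0, -4.0, 1.5]
reference = softmax_reference(scores)
streamed = softmax_streaming(scores)
print(max(abs(a - b) for a, b in zip(reference, streamed)))  # rounding-level difference
```

The rescaling step is what keeps the result exact: whenever a later block raises the maximum, the earlier partial sum is multiplied by the matching correction factor instead of being discarded or approximated.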
How it composes with other optimizations
Flash Attention is a foundation for, not a competitor to, the other optimizations in this cluster. KV cache quantization compresses the K and V tensors that Flash Attention reads; FA3's FP8 path and INT4/FP8 KV cache quantization both reduce memory bandwidth pressure. Prefix caching reuses cached KV blocks, so attention only has to be computed for the divergent suffix. Speculative decoding runs many short draft-model attention calls, each benefiting from the FA kernel. MLA (Multi-Head Latent Attention) compresses the KV cache before Flash Attention processes it. The techniques stack rather than compete.
FAQ
What does Flash Attention do?
Flash Attention computes exact attention without materializing the full N×N matrix in HBM. By tiling into SRAM-fitting blocks and recomputing in the backward pass, it reduces memory from O(N²) to O(N) while delivering 2-8x wall-clock speedups.
How much faster is FlashAttention-3 than FA2?
FlashAttention-3 delivers 1.5-2.0x speedup over FA2 on H100 GPUs, reaching up to 740 TFLOPs/s in FP16 (75% of H100 peak) via async Tensor Cores and TMA.
Does Flash Attention reduce model accuracy?
No. Flash Attention computes the exact same attention values as standard attention — it is mathematically equivalent, not an approximation. The gains come purely from reordering computation to fit the GPU memory hierarchy.
Related deep dives
- Prefix Caching in vLLM & SGLang — reuse the KV blocks Flash Attention computes
- KV Cache Quantization — compress the K, V tensors Flash Attention reads
- MLA Attention — compress the KV cache before Flash Attention processes it
- Speculative Decoding — each draft attention call uses the FA kernel
Sources
- Tri Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," arXiv:2205.14135 (FA1)
- Tri Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning," arXiv:2307.08691, 2023
- Jay Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision," arXiv:2407.08608 (740 TFLOPs/s, 1.5-2x over FA2)
- Stanford Hazy Research, "FlashAttention-2" blog, 2023
- Tri Dao, "FlashAttention-3" blog, 2024 (tridao.me)
- Lambda Blog, "How FlashAttention-2 Accelerates LLMs on NVIDIA H100 and A100," 2024
Speedup figures are hardware-dependent. FA3 gains require Hopper (H100); on Ampere and older, FA2 remains optimal. Verify kernel support for your specific GPU and model.