Flash Attention has fundamentally changed how we train and run large language models. By rethinking how attention computations interact with the GPU memory hierarchy, Tri Dao and his collaborators at Stanford achieved 2-4x speedups and 10-20x memory savings compared to standard attention implementations. This guide explains exactly how it works, when to use it, and how to get the most out of it.
What is Flash Attention?
Flash Attention is an IO-aware exact attention algorithm that computes the same mathematical result as standard scaled dot-product attention, but reorganizes the computation to minimize memory reads and writes between GPU High Bandwidth Memory (HBM) and on-chip SRAM.
The key insight is simple but profound: attention's bottleneck isn't compute, it's memory bandwidth. A modern GPU like the A100 delivers 312 TFLOPS of FP16 tensor-core compute but only about 2 TB/s of HBM bandwidth. Standard attention implementations read and write the full N×N attention matrix to HBM, leaving the compute units idle while they wait on memory.
Flash Attention solves this by:
- Tiling: Computing attention in blocks that fit entirely in SRAM
- Kernel fusion: Combining multiple operations into a single GPU kernel
- Recomputation: Trading minimal extra FLOPs to avoid storing large intermediate matrices
The result is mathematically identical outputs with dramatically better hardware utilization.
The Memory Hierarchy Problem
To understand Flash Attention, you need to understand GPU memory architecture:
| Memory Type | Size | Bandwidth | Latency |
|---|---|---|---|
| HBM (Global Memory) | 40-80 GB | 1.5-3 TB/s | ~400 cycles |
| L2 Cache | 40-50 MB | ~4 TB/s | ~200 cycles |
| SRAM (Shared Memory) | 192 KB per SM | ~19 TB/s | ~30 cycles |
| Registers | 256 KB per SM | Peak compute | 1 cycle |
Standard attention computes Q @ K^T, writes the N×N score matrix to HBM, applies softmax, writes it again, then computes attention @ V. For a sequence length of 4096, that N×N matrix alone is 32 MB per attention head in FP16, far more than SRAM can hold.
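To see the scale of the problem, here is a back-of-the-envelope calculation for a single layer (the numbers are illustrative, assuming FP16 and 32 attention heads):
# Size of the materialized attention score matrix
N = 4096                                  # sequence length
bytes_per_element = 2                     # FP16
score_matrix_bytes = N * N * bytes_per_element
print(score_matrix_bytes / 2**20)         # ~32 MB per head, vs ~192 KB of SRAM per SM
print(32 * score_matrix_bytes / 2**30)    # ~1 GB of HBM traffic per layer across 32 heads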
Flash Attention never materializes this full matrix. Instead, it processes Q, K, and V in tiles sized so that each block and its partial results fit in SRAM, keeping all intermediate values in fast on-chip memory.
How Flash Attention Works
The Standard Attention Formula
Standard scaled dot-product attention computes:
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
Where:
- Q (queries): shape [N, d]
- K (keys): shape [N, d]
- V (values): shape [N, d]
- d_k: head dimension, the d in the shapes above (typically 64-128)
- N: sequence length
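For reference, a minimal PyTorch sketch of this formula (single head, no batching, masking, or dropout; the standard_attention name is just for this guide):
import math
import torch

def standard_attention(Q, K, V):
    # Materializes the full N×N score matrix: O(N²) memory and HBM traffic
    S = Q @ K.T / math.sqrt(Q.shape[-1])  # [N, N] attention scores
    P = torch.softmax(S, dim=-1)          # [N, N] attention weights
    return P @ V                          # [N, d] output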
The Flash Attention Algorithm
Flash Attention processes attention in tiles. For each tile of queries Q_i:
- Load Q tile from HBM to SRAM
- For each K, V tile (inner loop):
  - Load K_j, V_j tiles to SRAM
  - Compute S_ij = Q_i @ K_j^T (stays in SRAM)
  - Track running max for numerical stability
  - Compute local softmax and update running output
- Write final output O_i back to HBM
The key innovation is the online softmax algorithm that allows computing exact softmax incrementally across tiles without storing the full attention matrix.
# Tile-by-tile FlashAttention forward pass, written as runnable PyTorch.
# This is a reference implementation: in the real kernel each tile lives in
# SRAM, while here everything stays in HBM, so it shows the math, not the speed.
import math
import torch

def flash_attention_forward(Q, K, V, block_size):
    Q, K, V = Q.float(), K.float(), V.float()  # compute in FP32 for a clear reference
    N, d = Q.shape
    O = torch.zeros(N, d, device=Q.device)
    L = torch.zeros(N, device=Q.device)  # per-row log-sum-exp, reused by the backward pass
    for i in range(0, N, block_size):  # outer loop over query tiles
        Q_i = Q[i:i + block_size]
        O_i = torch.zeros_like(Q_i)
        l_i = torch.zeros(Q_i.shape[0], device=Q.device)  # running softmax denominator
        m_i = torch.full((Q_i.shape[0],), float("-inf"), device=Q.device)  # running row max
        for j in range(0, N, block_size):  # inner loop over key/value tiles
            K_j = K[j:j + block_size]
            V_j = V[j:j + block_size]
            # Attention scores for this tile (kept on-chip in the real kernel)
            S_ij = Q_i @ K_j.T / math.sqrt(d)
            # Online softmax update: new running max, unnormalized probabilities,
            # and rescaled running denominator
            m_new = torch.maximum(m_i, S_ij.max(dim=-1).values)
            P_ij = torch.exp(S_ij - m_new.unsqueeze(-1))
            l_new = torch.exp(m_i - m_new) * l_i + P_ij.sum(dim=-1)
            # Rescale the previous partial output and add this tile's contribution
            O_i = torch.exp(m_i - m_new).unsqueeze(-1) * O_i + P_ij @ V_j
            m_i = m_new
            l_i = l_new
        # Final normalization for this query tile
        O[i:i + block_size] = O_i / l_i.unsqueeze(-1)
        L[i:i + block_size] = m_i + torch.log(l_i)
    return O, L
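Because the tiled computation is exact, its output matches a direct softmax(Q @ K^T / sqrt(d)) @ V to within floating-point rounding. A quick check, continuing from the snippet above and reusing the standard_attention sketch from the formula section (small shapes chosen arbitrarily):
torch.manual_seed(0)
N, d = 512, 64
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)

O_flash, _ = flash_attention_forward(Q, K, V, block_size=128)
O_ref = standard_attention(Q, K, V)
print(torch.allclose(O_flash, O_ref, atol=1e-5))  # True: same result, no N×N matrix ever stored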
Memory Complexity Analysis
| Algorithm | Memory | IO Complexity |
|---|---|---|
| Standard Attention | O(N²) | O(N² + Nd) |
| Flash Attention | O(N) | O(N²d² / M) |
Where M is the SRAM size (on the order of 100 KB). For typical values (N=4096, d=64, M≈100 KB), Flash Attention cuts HBM reads and writes by roughly an order of magnitude.
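Plugging in the numbers makes the gap concrete (a rough sketch that counts element accesses and treats M as the number of FP16 elements that fit in ~100 KB of SRAM):
N, d = 4096, 64
M = 100 * 1024 // 2                  # ~51K FP16 elements of SRAM
standard_io = N * N + N * d          # read/write the N×N matrix plus Q, K, V, O
flash_io = N * N * d * d // M        # Θ(N²d²/M) HBM accesses
print(standard_io / flash_io)        # ~13x fewer HBM reads and writes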
FlashAttention Versions Compared
FlashAttention-1 (2022)
The original implementation introduced the core tiled algorithm with online softmax. Key characteristics:
- Forward pass: Fully tiled; the N×N attention matrix is never written to HBM
- Backward pass: Recomputes the attention weights from the stored output and logsumexp
- Performance: 2-4x faster than PyTorch standard attention
- Memory: Linear in sequence length
FlashAttention-2 (2023)
Major algorithmic and implementation improvements:
- Parallelism: Split across sequence length, not just batch and heads
- Work partitioning: Better distribution across thread blocks and warps
- Reduced non-matmul FLOPs: Optimized softmax and masking operations
- Performance: ~2x faster than FlashAttention-1, reaching 50-73% of the theoretical maximum FLOPs/s on an A100
Key optimization: FlashAttention-2 parallelizes over the sequence dimension in the outer loop, enabling better GPU occupancy for long sequences.
FlashAttention-3 (2024)
Targets Hopper architecture (H100) with new hardware features:
- WGMMA instructions: Tensor Core operations with warp-group scope
- TMA: Tensor Memory Accelerator for async data movement
- FP8 support: Native 8-bit floating point for 2x throughput
- Pipelining: Overlapped compute and memory operations
- Performance: Up to 740 TFLOPS on H100 (75% utilization)
Performance Benchmarks
Benchmark results from the official FlashAttention papers on A100 80GB:
| Sequence Length | Standard Attention | FlashAttention-2 | Speedup |
|---|---|---|---|
| 512 | 142 TFLOPS | 194 TFLOPS | 1.37x |
| 1024 | 168 TFLOPS | 217 TFLOPS | 1.29x |
| 2048 | 135 TFLOPS | 223 TFLOPS | 1.65x |
| 4096 | 74 TFLOPS | 222 TFLOPS | 3.0x |
| 8192 | OOM | 218 TFLOPS | ∞ |
| 16384 | OOM | 210 TFLOPS | ∞ |
Memory usage for forward + backward pass (batch size 8, 12 heads, d=64):
| Sequence Length | Standard | FlashAttention-2 | Reduction |
|---|---|---|---|
| 2048 | 2.1 GB | 0.4 GB | 5.3x |
| 4096 | 8.4 GB | 0.7 GB | 12x |
| 8192 | OOM | 1.3 GB | - |
When to Use Flash Attention
Always Use Flash Attention For:
- Long sequences (>1024 tokens): Memory savings become dramatic
- Training: Both forward and backward pass benefit
- Memory-constrained scenarios: Enables larger batch sizes or longer contexts
- Production inference: Lower latency and higher throughput
Consider Standard Attention When:
- Very short sequences (<256 tokens): Overhead may not be worth it
- Custom attention patterns: Some sparse patterns need specialized implementations
- Debugging: intermediate attention weights are easier to inspect with standard attention
- Non-CUDA hardware: Flash Attention requires recent NVIDIA GPUs (FlashAttention-2 targets Ampere, Ada, and Hopper, i.e. compute capability 8.0+; the original FlashAttention also supported Turing); see the check sketched below
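A quick runtime check before opting in (a sketch; the exact cutoff depends on the flash-attn version you install, and "sdpa" is the Transformers fallback implementation):
import torch

def can_use_flash_attention(min_capability=(8, 0)):
    # FlashAttention-2 kernels target Ampere (SM 8.0) and newer NVIDIA GPUs
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= min_capability

attn_impl = "flash_attention_2" if can_use_flash_attention() else "sdpa"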
Installation and Usage
Installing Prebuilt Wheels
The fastest way to install flash-attn is using prebuilt wheels:
# Find your wheel at flashattn.dev, then:
pip install flash-attn --no-build-isolation
# Or use uv for faster installation:
uv pip install flash-attn --no-build-isolation
Building from source requires the CUDA toolkit and can take 30+ minutes; prebuilt wheels install in seconds.
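After installing, a quick sanity check confirms the build matches your environment (run on a machine with a supported GPU):
import torch
import flash_attn
from flash_attn import flash_attn_func  # raises ImportError if the CUDA extension failed to build

print(flash_attn.__version__)               # e.g. 2.x.y
print(torch.version.cuda)                   # CUDA version PyTorch was built against
print(torch.cuda.get_device_capability())   # (8, 0) or higher for FlashAttention-2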
Basic Usage with PyTorch
import torch
from flash_attn import flash_attn_func

# Input tensors: [batch, seqlen, nheads, headdim]
q = torch.randn(2, 1024, 12, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 12, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 12, 64, device='cuda', dtype=torch.float16)

# Flash Attention forward pass with causal masking
output = flash_attn_func(q, k, v, causal=True)
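Note the [batch, seqlen, nheads, headdim] layout: PyTorch's built-in scaled_dot_product_attention instead expects [batch, nheads, seqlen, headdim], so comparing the two requires transposes (a sanity-check sketch reusing q, k, v and output from above):
import torch.nn.functional as F

ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)
print(output.shape)                # torch.Size([2, 1024, 12, 64])
print((output - ref).abs().max())  # small (≈1e-3) difference in FP16 is expected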
Integration with Hugging Face Transformers
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # enable Flash Attention
)
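A minimal end-to-end check (a sketch assuming access to the gated Llama 2 weights and a GPU with enough memory for the model in FP16):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = model.to("cuda")  # Flash Attention kernels only run on CUDA devices

inputs = tokenizer("Flash Attention works by", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))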
Key Parameters
| Parameter | Description | Default |
|---|---|---|
| causal | Apply causal mask for autoregressive models | False |
| softmax_scale | Custom scaling factor (default 1/sqrt(d)) | None |
| dropout_p | Dropout probability during training | 0.0 |
| window_size | Sliding window for local attention | (-1, -1) |
| alibi_slopes | ALiBi positional encoding slopes | None |
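Putting several of these together (an illustrative call reusing the q, k, v tensors from the basic usage example, not a recommended configuration):
output = flash_attn_func(
    q, k, v,
    causal=True,             # autoregressive masking
    softmax_scale=None,      # defaults to 1/sqrt(headdim)
    dropout_p=0.1,           # attention dropout, typically only during training
    window_size=(1024, 0),   # attend to at most 1024 previous tokens, none ahead
    alibi_slopes=None,       # no ALiBi biases
)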
Common Issues and Solutions
CUDA Version Mismatch
RuntimeError: CUDA error: no kernel image is available for execution
Solution: Ensure your flash-attn wheel matches your CUDA version. Check with:
python -c "import torch; print(torch.version.cuda)"
Out of Memory Despite Using Flash Attention
Flash Attention reduces the memory used by the attention computation, but model weights, activations, and optimizer states still consume memory. Try the following (the first two are sketched after this list):
- Gradient checkpointing
- Mixed precision training (FP16/BF16)
- Reducing batch size
- Using FSDP or DeepSpeed for model sharding
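For example, gradient checkpointing and BF16 take only a couple of lines with Hugging Face Transformers (a sketch assuming a causal LM already loaded as model):
model.gradient_checkpointing_enable()    # recompute activations in backward to save memory
model = model.to(dtype=torch.bfloat16)   # or pass torch_dtype=torch.bfloat16 at load time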
Numerical Differences
Flash Attention results may differ slightly from standard attention (within 1e-3) due to:
- Different floating-point operation ordering
- Optimized softmax computation
This is expected and doesn't affect model quality.
Advanced Topics
Causal vs Bidirectional Attention
Causal attention (used in GPT-style models) only attends to previous tokens. Flash Attention handles this efficiently by skipping tiles that are entirely masked out:
# Causal (autoregressive) attention
output = flash_attn_func(q, k, v, causal=True)
# Bidirectional (BERT-style) attention
output = flash_attn_func(q, k, v, causal=False)
Sliding Window Attention
For models like Mistral that use local attention windows:
# Sliding window of 4096 tokens, combined with causal masking for decoder models
output = flash_attn_func(q, k, v, causal=True, window_size=(4096, 4096))
Custom Attention Masks
Arbitrary dense attention masks are not supported by the fused kernels. The original FlashAttention release shipped a block-sparse variant for sparse attention patterns, and inference with a key/value cache is handled by the separate flash_attn_with_kvcache function.
References
- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. arXiv:2205.14135
- Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691
- Shah, J., et al. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." arXiv:2407.08608
- Block-Sparse FlashAttention. (2025). "Block Sparse Flash Attention." arXiv:2512.07011
- Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS 2017. arXiv:1706.03762
- NVIDIA. (2025). "CUDA C++ Programming Guide." NVIDIA Documentation