
The Complete Guide to Flash Attention: How It Works and Why It Matters

Deep dive into Flash Attention's IO-aware algorithm, memory hierarchy optimization, and why it delivers 2-4x speedups. Covers FlashAttention-1, 2, and 3 with benchmarks and implementation details.

Flash Attention Team · January 8, 2026 · 10 min read
Tags: flash attention, transformer optimization, attention mechanism, CUDA, GPU memory, LLM training

Flash Attention has fundamentally changed how we train and run large language models. By rethinking how attention computations interact with GPU memory hierarchy, Tri Dao and his team at Stanford achieved 2-4x speedups and 5-20x memory reduction compared to standard attention implementations. This guide explains exactly how it works, when to use it, and how to get the most out of it.

What is Flash Attention?

Flash Attention is an IO-aware exact attention algorithm that computes the same mathematical result as standard scaled dot-product attention, but reorganizes the computation to minimize memory reads and writes between GPU High Bandwidth Memory (HBM) and on-chip SRAM.

The key insight is simple but profound: attention's bottleneck isn't compute, it's memory bandwidth. A modern GPU like the A100 offers 312 TFLOPS of FP16 compute but only about 2 TB/s of HBM bandwidth. Standard attention implementations read and write the full N×N attention matrix to HBM, leaving the compute units idle most of the time while they wait on memory.

Flash Attention solves this by:

  • Tiling: Computing attention in blocks that fit entirely in SRAM
  • Kernel fusion: Combining multiple operations into a single GPU kernel
  • Recomputation: Trading minimal extra FLOPs to avoid storing large intermediate matrices

The result is mathematically identical outputs with dramatically better hardware utilization.

The Memory Hierarchy Problem

To understand Flash Attention, you need to understand GPU memory architecture:

| Memory Type | Size | Bandwidth | Latency |
| --- | --- | --- | --- |
| HBM (Global Memory) | 40-80 GB | 1.5-3 TB/s | ~400 cycles |
| L2 Cache | 40-50 MB | ~4 TB/s | ~200 cycles |
| SRAM (Shared Memory) | 192 KB per SM | ~19 TB/s | ~30 cycles |
| Registers | 256 KB per SM | Peak compute | 1 cycle |

Standard attention computes Q @ K^T, stores the N×N matrix to HBM, applies softmax, stores again, then computes attention @ V. For a sequence length of 4096, that N×N matrix alone is 32 MB in FP16 per attention head, far exceeding SRAM capacity.

Flash Attention never materializes this full matrix. Instead, it processes Q, K, and V in tiles sized so that each block, together with its intermediate results, fits in fast on-chip SRAM.
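
To make these numbers concrete, here is a rough back-of-the-envelope sketch (the 192 KB figure is the per-SM shared memory from the table above; real kernels tune block sizes per GPU, so treat the tile size as illustrative):

# Rough memory math for standard vs. tiled attention in FP16 (2 bytes/element)
N, d = 4096, 64                               # sequence length, head dimension
bytes_fp16 = 2

full_scores = N * N * bytes_fp16              # the N x N score matrix, per head
print(f"Full attention matrix: {full_scores / 2**20:.0f} MiB")        # ~32 MiB

sram_per_sm = 192 * 1024                      # on-chip shared memory per SM (A100-class)
block = 128                                   # an illustrative tile size
tile_bytes = (3 * block * d + block * block) * bytes_fp16   # Q, K, V tiles + score tile
print(f"One tile's working set: {tile_bytes / 1024:.0f} KiB of {sram_per_sm / 1024:.0f} KiB SRAM")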

How Flash Attention Works

The Standard Attention Formula

Standard scaled dot-product attention computes:

Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V

Where:

  • Q (queries): shape [N, d]
  • K (keys): shape [N, d]
  • V (values): shape [N, d]
  • d_k: head dimension (typically 64-128)
  • N: sequence length
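
For reference, here is the same formula as a naive PyTorch implementation; it materializes the full N×N score matrix, which is exactly the step Flash Attention avoids:

import math
import torch

def standard_attention(Q, K, V):
    # Q, K, V: [N, d]; builds the full N x N score matrix in memory
    d_k = Q.shape[-1]
    scores = Q @ K.T / math.sqrt(d_k)        # [N, N]
    weights = torch.softmax(scores, dim=-1)
    return weights @ V                       # [N, d]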

The Flash Attention Algorithm

Flash Attention processes attention in tiles. For each tile of queries Q_i:

  1. Load Q tile from HBM to SRAM
  2. For each K, V tile (inner loop):
    • Load K_j, V_j tiles to SRAM
    • Compute S_ij = Q_i @ K_j^T (stays in SRAM)
    • Track running max for numerical stability
    • Compute local softmax and update running output
  3. Write final output O_i back to HBM

The key innovation is the online softmax algorithm that allows computing exact softmax incrementally across tiles without storing the full attention matrix.

# Tiled Flash Attention forward pass (runnable PyTorch reference; the real
# kernel keeps each tile in SRAM and fuses these steps into a single GPU kernel)
import math
import torch

def flash_attention_forward(Q, K, V, block_size):
    N, d = Q.shape
    O = torch.zeros(N, d, dtype=Q.dtype, device=Q.device)
    L = torch.zeros(N, dtype=Q.dtype, device=Q.device)  # log-sum-exp for softmax normalization

    for i in range(0, N, block_size):  # Query tiles (outer loop)
        Q_i = Q[i:i+block_size]
        O_i = torch.zeros_like(Q_i)
        l_i = torch.zeros(Q_i.shape[0], dtype=Q.dtype, device=Q.device)
        m_i = torch.full((Q_i.shape[0],), float('-inf'), dtype=Q.dtype, device=Q.device)  # Running max

        for j in range(0, N, block_size):  # Key/Value tiles (inner loop)
            K_j = K[j:j+block_size]
            V_j = V[j:j+block_size]

            # Attention scores for this tile (stay on-chip in the real kernel)
            S_ij = Q_i @ K_j.T / math.sqrt(d)

            # Online softmax update: maintain the running row max for numerical stability
            m_new = torch.maximum(m_i, S_ij.max(dim=-1).values)
            P_ij = torch.exp(S_ij - m_new.unsqueeze(-1))
            l_new = torch.exp(m_i - m_new) * l_i + P_ij.sum(dim=-1)

            # Rescale the running output, then accumulate this tile's contribution
            O_i = torch.exp(m_i - m_new).unsqueeze(-1) * O_i + P_ij @ V_j

            m_i = m_new
            l_i = l_new

        # Final normalization; L is the log-sum-exp needed by the backward pass
        O[i:i+block_size] = O_i / l_i.unsqueeze(-1)
        L[i:i+block_size] = m_i + torch.log(l_i)

    return O, L
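
As a quick sanity check, the tiled function above can be compared against the naive standard_attention sketch from earlier; the outputs match because the math is identical, only the order of operations differs:

# Compare the tiled forward pass against naive attention on random inputs
torch.manual_seed(0)
Q = torch.randn(256, 64)
K = torch.randn(256, 64)
V = torch.randn(256, 64)

O_tiled, _ = flash_attention_forward(Q, K, V, block_size=64)
O_ref = standard_attention(Q, K, V)
print(torch.allclose(O_tiled, O_ref, atol=1e-5))  # True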

Memory Complexity Analysis

| Algorithm | Memory | HBM IO Complexity |
| --- | --- | --- |
| Standard Attention | O(N²) | O(N² + Nd) |
| Flash Attention | O(N) | O(N²d²/M) |

Where M is SRAM size (~100KB). For typical values (N=4096, d=64, M=100KB), Flash Attention reduces memory reads/writes by ~10x.
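
Plugging those numbers into the bounds above gives a feel for the savings (M is counted here in FP16 elements and constant factors are ignored, so this is order-of-magnitude only):

# Order-of-magnitude HBM access counts for N=4096, d=64, ~100 KB of SRAM
N, d = 4096, 64
M = 100 * 1024 // 2                  # SRAM capacity in FP16 elements

standard_io = N * N + N * d          # materialize the N x N matrix, plus Q, K, V, O
flash_io = N * N * d * d // M        # Theta(N^2 * d^2 / M) from the FlashAttention paper
print(standard_io // flash_io)       # roughly an order of magnitude fewer HBM accesses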

FlashAttention Versions Compared

FlashAttention-1 (2022)

The original implementation introduced the core tiled algorithm with online softmax. Key characteristics:

  • Forward pass: Full tiling with recomputation
  • Backward pass: Recomputes attention weights from stored logsumexp
  • Performance: 2-4x faster than PyTorch standard attention
  • Memory: Linear in sequence length

FlashAttention-2 (2023)

Major algorithmic and implementation improvements:

  • Parallelism: Split across sequence length, not just batch and heads
  • Work partitioning: Better distribution across thread blocks and warps
  • Reduced non-matmul FLOPs: Optimized softmax and masking operations
  • Performance: ~2x faster than FlashAttention-1, reaching 50-73% of the A100's theoretical peak FLOP/s

Key optimization: FlashAttention-2 parallelizes over the sequence dimension in the outer loop, enabling better GPU occupancy for long sequences.

FlashAttention-3 (2024)

Targets Hopper architecture (H100) with new hardware features:

  • WGMMA instructions: Tensor Core operations with warp-group scope
  • TMA: Tensor Memory Accelerator for async data movement
  • FP8 support: Native 8-bit floating point for 2x throughput
  • Pipelining: Overlapped compute and memory operations
  • Performance: Up to 740 TFLOPS on H100 (75% utilization)

Performance Benchmarks

Benchmark results from the official FlashAttention papers on A100 80GB:

| Sequence Length | Standard Attention | FlashAttention-2 | Speedup |
| --- | --- | --- | --- |
| 512 | 142 TFLOPS | 194 TFLOPS | 1.37x |
| 1024 | 168 TFLOPS | 217 TFLOPS | 1.29x |
| 2048 | 135 TFLOPS | 223 TFLOPS | 1.65x |
| 4096 | 74 TFLOPS | 222 TFLOPS | 3.0x |
| 8192 | OOM | 218 TFLOPS | - |
| 16384 | OOM | 210 TFLOPS | - |

Memory usage for forward + backward pass (batch size 8, 12 heads, d=64):

| Sequence Length | Standard | FlashAttention-2 | Reduction |
| --- | --- | --- | --- |
| 2048 | 2.1 GB | 0.4 GB | 5.3x |
| 4096 | 8.4 GB | 0.7 GB | 12x |
| 8192 | OOM | 1.3 GB | - |

When to Use Flash Attention

Always Use Flash Attention For:

  • Long sequences (>1024 tokens): Memory savings become dramatic
  • Training: Both forward and backward pass benefit
  • Memory-constrained scenarios: Enables larger batch sizes or longer contexts
  • Production inference: Lower latency and higher throughput

Consider Standard Attention When:

  • Very short sequences (<256 tokens): Overhead may not be worth it
  • Custom attention patterns: Some sparse patterns need specialized implementations
  • Debugging: intermediate attention weights are easier to inspect with standard attention
  • Non-NVIDIA or older hardware: FlashAttention-2 requires Ampere or newer NVIDIA GPUs (compute capability 8.0+); Turing cards are supported only by the 1.x releases

Installation and Usage

Installing Prebuilt Wheels

The fastest way to install flash-attn is using prebuilt wheels:

# Find your wheel at flashattn.dev, then:
pip install flash-attn --no-build-isolation

# Or use uv for faster installation:
uv pip install flash-attn

Building from source requires the CUDA toolkit and can take 30+ minutes; prebuilt wheels install in seconds.
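
After installing, a quick import check confirms the build matches your environment (assuming a recent flash-attn release, which exposes __version__):

# Verify that flash-attn imports and report the relevant versions
import torch
import flash_attn

print("flash-attn:", flash_attn.__version__)
print("torch CUDA:", torch.version.cuda)
print("GPU compute capability:", torch.cuda.get_device_capability())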

Basic Usage with PyTorch

import torch
from flash_attn import flash_attn_func

# Input tensors: [batch, seqlen, nheads, headdim]
q = torch.randn(2, 1024, 12, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 12, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 12, 64, device='cuda', dtype=torch.float16)

# Flash Attention forward pass
output = flash_attn_func(q, k, v, causal=True)

Integration with Hugging Face Transformers

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"  # Enable Flash Attention
)

Key Parameters

| Parameter | Description | Default |
| --- | --- | --- |
| causal | Apply causal mask for autoregressive models | False |
| softmax_scale | Custom scaling factor (default 1/sqrt(d)) | None |
| dropout_p | Dropout probability during training | 0.0 |
| window_size | Sliding window for local attention | (-1, -1) |
| alibi_slopes | ALiBi positional encoding slopes | None |
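
Putting a few of these together, for example causal attention with dropout and a local window (the values here are illustrative, not recommendations):

# Causal attention with dropout and a 1024-token sliding window
output = flash_attn_func(
    q, k, v,
    dropout_p=0.1,           # dropout on attention probabilities (use 0.0 for inference)
    softmax_scale=None,      # defaults to 1/sqrt(headdim)
    causal=True,
    window_size=(1024, 0),   # attend 1024 tokens back, none ahead
    alibi_slopes=None,
)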

Common Issues and Solutions

CUDA Version Mismatch

RuntimeError: CUDA error: no kernel image is available for execution

Solution: Ensure your flash-attn wheel matches your CUDA version. Check with:

python -c "import torch; print(torch.version.cuda)"

Out of Memory Despite Using Flash Attention

Flash Attention reduces attention memory, but other model components (weights, activations, optimizer states) still consume memory. Options to try (a minimal sketch of the first two follows the list):

  • Gradient checkpointing
  • Mixed precision training (FP16/BF16)
  • Reducing batch size
  • Using FSDP or DeepSpeed for model sharding
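
A minimal sketch of the first two options with Hugging Face Transformers, reusing the model from the earlier example (exact settings depend on your training setup):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,               # mixed precision weights and activations
    attn_implementation="flash_attention_2",
)
model.gradient_checkpointing_enable()         # trade recompute for activation memory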

Numerical Differences

Flash Attention results may differ slightly from standard attention (within 1e-3) due to:

  • Different floating-point operation ordering
  • Optimized softmax computation

This is expected and doesn't affect model quality.
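
You can confirm this yourself by comparing flash_attn_func against PyTorch's built-in scaled_dot_product_attention on the same inputs; note the layout difference (flash_attn_func takes [batch, seqlen, nheads, headdim], the PyTorch op takes [batch, nheads, seqlen, headdim]). This assumes a CUDA GPU with flash-attn installed:

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 12, 64, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out_flash = flash_attn_func(q, k, v, causal=True)
# PyTorch's SDPA expects [batch, nheads, seqlen, headdim]
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

print((out_flash - out_ref).abs().max())  # small, typically around 1e-3 in FP16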

Advanced Topics

Causal vs Bidirectional Attention

Causal attention (used in GPT-style models) only attends to previous tokens. Flash Attention handles this efficiently by skipping computation for masked positions:

# Causal (autoregressive) attention
output = flash_attn_func(q, k, v, causal=True)

# Bidirectional (BERT-style) attention
output = flash_attn_func(q, k, v, causal=False)

Sliding Window Attention

For models like Mistral that use local attention windows:

# Sliding window of 4096 tokens
output = flash_attn_func(q, k, v, window_size=(4096, 4096))

Custom Attention Masks

The original FlashAttention paper also describes a block-sparse variant of the algorithm for structured sparsity patterns; for incremental decoding against a key/value cache, the library provides flash_attn_with_kvcache.

References

  1. Dao, T., Fu, D.Y., Ermon, S., Rudra, A., & Ré, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. arXiv:2205.14135

  2. Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691

  3. Shah, J., et al. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." arXiv:2407.08608

  4. Block-Sparse FlashAttention. (2025). "Block Sparse Flash Attention." arXiv:2512.07011

  5. Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS 2017. arXiv:1706.03762

  6. NVIDIA. (2025). "CUDA C++ Programming Guide." NVIDIA Documentation


Need Flash Attention wheels?

Skip the 30+ minute compilation. Find prebuilt wheels for your exact configuration.
