
Flash Attention vs Standard Attention: Benchmarks, Memory, and Performance

Head-to-head comparison of Flash Attention and standard PyTorch attention. Includes benchmarks, memory usage analysis, and guidance on when each approach wins.

Flash Attention Team · January 8, 2026 · 8 min read
flash attention · attention benchmark · memory optimization · PyTorch · transformer performance

When implementing attention in transformers, you face a fundamental choice: standard PyTorch attention or Flash Attention. This guide provides comprehensive benchmarks and analysis to help you make the right decision for your use case.

The Core Difference

Standard attention and Flash Attention compute identical mathematical results. The difference lies entirely in implementation:

| Aspect | Standard Attention | Flash Attention |
| --- | --- | --- |
| Algorithm | Naive matrix multiplication | IO-aware tiled algorithm |
| Memory Pattern | Materializes N×N matrix | Never stores full matrix |
| Memory Complexity | O(N²) | O(N) |
| Kernel Count | Multiple separate kernels | Single fused kernel |
| Hardware Utilization | 30-50% | 50-75% |
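
Since the two approaches are mathematically equivalent, you can verify the claim directly. The sketch below compares a naive implementation against PyTorch's fused scaled_dot_product_attention (which dispatches to a Flash-style kernel on supported GPUs); the shapes and tolerance are illustrative, and a CUDA device is assumed.

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# [batch, heads, seq, head_dim]
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Naive path: materializes the full seq x seq score matrix
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
naive = torch.softmax(scores, dim=-1) @ v

# Fused path: same result, no materialized score matrix
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(naive, fused, atol=1e-2))  # True up to FP16 rounding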

Memory Usage Comparison

Memory consumption is the most dramatic difference. Here's actual measured memory for the attention computation alone (excluding Q, K, V storage):

Forward Pass Memory (FP16, batch=8, heads=12, d=64)

| Sequence Length | Standard | Flash Attention | Reduction |
| --- | --- | --- | --- |
| 256 | 24 MB | 8 MB | 3x |
| 512 | 96 MB | 12 MB | 8x |
| 1024 | 384 MB | 24 MB | 16x |
| 2048 | 1.5 GB | 48 MB | 32x |
| 4096 | 6.1 GB | 96 MB | 64x |
| 8192 | OOM | 192 MB | - |
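
These numbers follow directly from the size of the N×N score matrix: batch × heads × N × N × 2 bytes in FP16, plus roughly one more copy for the softmax output. A back-of-the-envelope check (the factor of two is an assumption about what the standard implementation keeps live):

batch, heads, fp16_bytes = 8, 12, 2
for n in [256, 512, 1024, 2048, 4096]:
    scores = batch * heads * n * n * fp16_bytes   # raw QK^T scores
    total = 2 * scores                            # scores + softmax probabilities
    print(f"N={n}: ~{total / 2**20:.0f} MiB")     # N=2048 -> ~1536 MiB, matching the table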

Forward + Backward Pass Memory

Training requires storing activations for the backward pass. Flash Attention's recomputation strategy dramatically reduces this:

| Sequence Length | Standard | Flash Attention | Reduction |
| --- | --- | --- | --- |
| 1024 | 1.2 GB | 0.2 GB | 6x |
| 2048 | 4.7 GB | 0.4 GB | 12x |
| 4096 | 18.9 GB | 0.8 GB | 24x |
| 8192 | OOM | 1.5 GB | - |

The O(N²) vs O(N) scaling means Flash Attention becomes increasingly advantageous as sequence length grows.
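
You can reproduce these measurements on your own hardware with PyTorch's peak-memory counters. A minimal sketch (the shapes mirror the tables above; swap the fused call for a naive implementation to compare):

import torch
import torch.nn.functional as F

def attention_peak_mib(seq_len, batch=8, heads=12, dim=64):
    q = torch.randn(batch, heads, seq_len, dim, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated()        # memory already held by q, k, v
    out = F.scaled_dot_product_attention(q, k, v)   # replace with a naive version to compare
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - baseline) / 2**20

print(f"{attention_peak_mib(4096):.0f} MiB")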

Speed Benchmarks

Performance was measured on NVIDIA A100 80GB with PyTorch 2.1 and CUDA 12.1:

Forward Pass Throughput (TFLOPS)

| Sequence Length | Standard PyTorch | Flash Attention-2 | Speedup |
| --- | --- | --- | --- |
| 256 | 148 | 162 | 1.09x |
| 512 | 156 | 189 | 1.21x |
| 1024 | 142 | 212 | 1.49x |
| 2048 | 98 | 221 | 2.26x |
| 4096 | 52 | 219 | 4.21x |
| 8192 | OOM | 215 | - |
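
To reproduce throughput numbers like these, time the kernel with CUDA events and convert to TFLOPS. The sketch below uses the common 4·batch·heads·N²·d FLOP estimate for forward attention (it ignores the softmax) and PyTorch's fused kernel as the measured operation:

import torch
import torch.nn.functional as F

def attention_tflops(seq_len, batch=8, heads=12, dim=64, iters=50):
    q = torch.randn(batch, heads, seq_len, dim, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    for _ in range(10):                                  # warmup
        F.scaled_dot_product_attention(q, k, v)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v)
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1000 / iters     # elapsed_time returns milliseconds
    flops = 4 * batch * heads * seq_len**2 * dim         # QK^T plus attn @ V
    return flops / seconds / 1e12

print(f"{attention_tflops(2048):.1f} TFLOPS")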

End-to-End Training Throughput

For a full transformer layer (not just attention), measured in samples/second:

| Model Size | Sequence | Standard | Flash Attention | Speedup |
| --- | --- | --- | --- | --- |
| 125M | 1024 | 2,840 | 3,520 | 1.24x |
| 350M | 1024 | 1,120 | 1,460 | 1.30x |
| 1.3B | 2048 | 186 | 312 | 1.68x |
| 6.7B | 2048 | 42 | 78 | 1.86x |
| 6.7B | 4096 | OOM | 38 | - |

Why Flash Attention is Faster

Memory Bandwidth Bottleneck

Modern GPUs have massive compute capacity but limited memory bandwidth:

A100 GPU:
- FP16 Compute: 312 TFLOPS
- HBM Bandwidth: 2 TB/s
- SRAM Bandwidth: ~19 TB/s

Standard attention is memory-bound: it reads/writes the N×N attention matrix to HBM twice (after QK^T and after softmax). This wastes compute cycles waiting for memory.

Flash Attention is compute-bound: by keeping intermediate results in SRAM, it achieves 50-73% of theoretical peak FLOPS vs 30-50% for standard attention.
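
A quick back-of-the-envelope calculation makes the bottleneck concrete. Treating the score matrix as written and read twice (once for the raw scores, once for the softmax output) is a simplification, but it shows why the HBM traffic, not the math, dominates at long sequence lengths:

batch, heads, dim, n = 8, 12, 64, 4096
scores_bytes = batch * heads * n * n * 2                 # one FP16 N x N matrix per head
hbm_traffic = 4 * scores_bytes                           # write + read scores, write + read probs
hbm_seconds = hbm_traffic / 2e12                         # A100 HBM: ~2 TB/s
flops = 4 * batch * heads * n * n * dim                  # QK^T plus attn @ V
compute_seconds = flops / 312e12                         # A100 FP16 peak: 312 TFLOPS
print(f"HBM traffic: {hbm_traffic / 2**30:.1f} GiB, ~{hbm_seconds * 1e3:.1f} ms")
print(f"Ideal compute time: ~{compute_seconds * 1e3:.1f} ms")

At N=4096 the round trips to HBM take several times longer than the matmuls themselves, which is exactly the gap Flash Attention closes by keeping tiles in SRAM.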

Kernel Launch Overhead

Standard attention requires multiple kernel launches:

# Standard attention (simplified) -- each step launches its own CUDA kernel
scores = torch.matmul(Q, K.transpose(-2, -1))        # Kernel 1: QK^T
scores = scores / math.sqrt(d_k)                     # Kernel 2: scale
scores = scores.masked_fill(mask, float("-inf"))     # Kernel 3: apply mask
attn_weights = F.softmax(scores, dim=-1)             # Kernel 4: softmax
attn_weights = F.dropout(attn_weights, p)            # Kernel 5: dropout
output = torch.matmul(attn_weights, V)               # Kernel 6: weighted sum of values

Flash Attention fuses all operations into a single kernel, eliminating launch overhead and intermediate memory allocations.

When Standard Attention Wins

Despite Flash Attention's advantages, there are scenarios where standard attention may be preferable:

Very Short Sequences (< 256 tokens)

Flash Attention has initialization overhead that dominates for short sequences:

| Sequence | Standard | Flash Attention | Winner |
| --- | --- | --- | --- |
| 64 | 2.1 ms | 2.3 ms | Standard |
| 128 | 2.8 ms | 2.9 ms | Tie |
| 256 | 4.2 ms | 3.8 ms | Flash |

Debugging and Development

Standard attention lets you inspect intermediate values:

# Easy to debug with standard attention
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
print(f"Attention scores: {scores}")
print(f"Max score: {scores.max()}, Min: {scores.min()}")
weights = F.softmax(scores, dim=-1)
print(f"Attention weights sum: {weights.sum(dim=-1)}")  # Should be 1.0 along the last dimension

Flash Attention is a black box—you only get the final output.
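
When you do need the weights, a common pattern is to keep a slow reference path next to the fused one and only enable it on small inputs. This is a sketch, not a library API; the helper name and flag are illustrative:

import math
import torch
import torch.nn.functional as F

def attention_with_optional_weights(q, k, v, return_weights=False):
    if not return_weights:
        # Fast fused path: no access to intermediate weights
        return F.scaled_dot_product_attention(q, k, v), None
    # Slow reference path: materializes the score matrix for inspection
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights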

Custom Attention Patterns

If you need highly irregular sparse attention patterns, standard attention with custom masks may be easier to implement:

# Complex custom mask - easier with standard attention
def create_custom_mask(seq_len):
    mask = torch.zeros(seq_len, seq_len)
    # Custom logic for which positions attend to which
    for i in range(seq_len):
        for j in range(seq_len):
            if custom_attention_logic(i, j):
                mask[i, j] = 1
    return mask

Flash Attention supports causal, sliding window, and block-sparse patterns, but arbitrary custom patterns require workarounds.
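
If you already have an irregular 0/1 mask like the one above, you can still use PyTorch's fused path by passing it as a boolean attn_mask (True means "may attend"); note that SDPA will usually route such masks to its memory-efficient or math backend rather than the Flash kernel. A sketch with a concrete, self-contained pattern standing in for custom_attention_logic:

import torch
import torch.nn.functional as F

seq_len = 512
idx = torch.arange(seq_len)
# Illustrative irregular pattern: attend to position 0 and to same-parity positions
mask = (idx[:, None] % 2 == idx[None, :] % 2) | (idx[None, :] == 0)

q = torch.randn(1, 8, seq_len, 64)
k, v = torch.randn_like(q), torch.randn_like(q)

# Boolean masks broadcast over batch and heads; True = keep, False = mask out
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)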

Non-NVIDIA Hardware

Flash Attention requires NVIDIA GPUs with CUDA. For AMD, Intel, or Apple Silicon:

| Hardware | Flash Attention | Alternative |
| --- | --- | --- |
| NVIDIA (CC 8.0+; CC 7.5 via FlashAttention 1.x) | Supported | - |
| AMD ROCm | Not supported | xformers, PyTorch SDPA |
| Intel | Not supported | PyTorch SDPA |
| Apple Silicon | Not supported | PyTorch MPS |

PyTorch Scaled Dot-Product Attention (SDPA)

PyTorch 2.0+ includes torch.nn.functional.scaled_dot_product_attention which automatically selects the best backend:

import torch.nn.functional as F

# Automatic backend selection
output = F.scaled_dot_product_attention(
    query, key, value,
    dropout_p=0.1,
    is_causal=True,   # pass attn_mask=... instead for non-causal masking (not both at once)
)

SDPA can use one of three backends (a sketch for pinning and checking the choice follows this list):

  1. Flash Attention kernels built into PyTorch (when the hardware, dtypes, and mask allow it)
  2. Memory-efficient attention (xformers-style)
  3. Standard attention (fallback)
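
To know which backend you are actually getting, you can pin the choice with the sdp_kernel context manager (available in PyTorch 2.x; newer releases prefer torch.nn.attention.sdpa_kernel). When the requested backend cannot handle the inputs, the call raises an error instead of silently falling back:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the Flash backend; fails loudly if it cannot run with these inputs
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)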

SDPA vs Direct Flash Attention

| Feature | SDPA | Direct flash_attn |
| --- | --- | --- |
| Automatic fallback | Yes | No |
| Cross-platform | Yes | NVIDIA only |
| Sliding window | Limited | Full support |
| ALiBi | No | Yes |
| Variable length batching | No | Yes |
| Max performance | ~95% | 100% |

For maximum performance and features, use flash_attn directly. For portability, use SDPA.
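
The variable-length batching row deserves a concrete example, since SDPA has no equivalent: flash_attn_varlen_func takes sequences packed into one tensor with no padding, described by cumulative sequence lengths. This is a sketch under the flash-attn 2.x API; check the signature of the version you have installed, since argument names have shifted between releases.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

# Two sequences (lengths 100 and 28) packed into one tensor with no padding
seqlens = torch.tensor([100, 28], dtype=torch.int32, device="cuda")
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))  # [0, 100, 128]
total, heads, dim = int(seqlens.sum()), 12, 64
q = torch.randn(total, heads, dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    causal=True,
)  # out has shape [total_tokens, heads, head_dim]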

Migration Guide

From Standard to Flash Attention

# Before: Standard attention
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class StandardAttention(nn.Module):
    def forward(self, q, k, v, mask=None):
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, v)

# After: Flash Attention
from flash_attn import flash_attn_func

class FlashAttention(nn.Module):
    def forward(self, q, k, v, causal=False):
        # Note: flash_attn expects [batch, seq, heads, dim]
        # Standard attention typically uses [batch, heads, seq, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        out = flash_attn_func(q, k, v, causal=causal)
        return out.transpose(1, 2)
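
A usage note on the migrated module: flash_attn_func only accepts CUDA tensors in FP16 or BF16, whereas the standard module also runs in FP32 and on CPU. A minimal sketch of calling the class above under those assumptions:

import torch

attn = FlashAttention()
# [batch, heads, seq, head_dim], FP16 on GPU as flash_attn requires
q = torch.randn(4, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attn(q, k, v, causal=True)
print(out.shape)  # torch.Size([4, 8, 1024, 64])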

Handling Shape Differences

Flash Attention uses shape [batch, seqlen, heads, headdim], while most PyTorch attention code (including F.scaled_dot_product_attention and hand-rolled multi-head implementations) works with [batch, heads, seqlen, headdim]:

# Convert from PyTorch MHA format to Flash Attention format
def convert_to_flash_format(tensor):
    # [batch, heads, seq, dim] -> [batch, seq, heads, dim]
    return tensor.transpose(1, 2).contiguous()

# Convert back
def convert_from_flash_format(tensor):
    # [batch, seq, heads, dim] -> [batch, heads, seq, dim]
    return tensor.transpose(1, 2).contiguous()

Practical Recommendations

Use Flash Attention When:

  • Sequence length > 512 tokens
  • Training or inference with memory constraints
  • Using supported NVIDIA hardware (A100, H100, RTX 30/40 series; Turing GPUs only via FlashAttention 1.x)
  • Standard attention patterns (causal, full, sliding window)

Use Standard Attention When:

  • Sequence length < 256 tokens
  • Debugging attention mechanisms
  • Need to inspect intermediate attention weights
  • Running on non-NVIDIA hardware
  • Using highly custom attention patterns

Use PyTorch SDPA When:

  • Want automatic backend selection
  • Need cross-platform compatibility
  • Don't need advanced Flash Attention features
  • Using PyTorch 2.0+ ecosystem

References

  1. Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691

  2. Lefaudeux, B., et al. (2022). "xFormers: A modular and hackable Transformer modelling library." GitHub

  3. PyTorch Team. (2025). "Scaled Dot Product Attention." PyTorch Documentation

  4. NVIDIA. (2025). "CUDA Toolkit Documentation." NVIDIA Developer

