FlashAttention-3, released in 2024, brings significant architectural optimizations for NVIDIA's Hopper GPUs (H100). This guide compares the two versions and helps you decide when upgrading makes sense.
Version Overview
| Feature | FlashAttention-2 | FlashAttention-3 |
|---|---|---|
| Release | July 2023 | July 2024 |
| Target Architecture | Ampere (A100) | Hopper (H100) |
| Peak Performance | 230 TFLOPS (A100) | 740 TFLOPS (H100) |
| FP8 Support | No | Yes |
| Asynchronous Ops | Limited | Full TMA support |
| Min CUDA | 11.6 | 12.3 |
What's New in FlashAttention-3
1. Warp Group Matrix Multiply (WGMMA)
FlashAttention-3 leverages Hopper's WGMMA instructions, which operate at warp-group scope (4 warps = 128 threads) instead of warp scope (32 threads):
- FlashAttention-2: warp-level mma.sync instructions
- FlashAttention-3: warp-group-level wgmma.mma_async instructions
WGMMA enables:
- Larger tile sizes: 128×128 or 256×64 vs 64×64
- Better register utilization: Distributes data across more threads
- Higher arithmetic intensity: More compute per memory access
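To see why bigger tiles raise arithmetic intensity, here is a back-of-the-envelope Python check (a sketch using a simplified tile-GEMM model with FP16 operands, not a measurement of the actual kernels):

def tile_arithmetic_intensity(m, n, k, bytes_per_elem=2):
    """FLOPs per byte moved for an m x n output tile accumulated over a k-deep slab."""
    flops = 2 * m * n * k                                     # multiply-accumulate = 2 FLOPs
    bytes_moved = bytes_per_elem * (m * k + n * k + m * n)    # A-tile + B-tile + C-tile
    return flops / bytes_moved

print(tile_arithmetic_intensity(64, 64, 64))     # ~21 FLOPs/byte (FA2-style 64x64 tile)
print(tile_arithmetic_intensity(128, 128, 64))   # ~32 FLOPs/byte (FA3-style 128x128 tile)

In this simplified model, the 128×128 tile does roughly 1.5x more work per byte of shared-memory traffic than the 64×64 tile.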
2. Tensor Memory Accelerator (TMA)
H100 introduces TMA for hardware-accelerated asynchronous data movement:
// FlashAttention-2: Manual async copy
cp.async.ca.shared.global [smem_ptr], [gmem_ptr], 16;
cp.async.commit_group;
cp.async.wait_group 0;
// FlashAttention-3: TMA descriptor-based copy
cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes
[smem_ptr], [tensor_map, coords], [mbar];
TMA benefits:
- Address generation offload: GPU computes addresses in hardware
- Multicast support: Single memory read broadcasts to multiple SMs
- Better overlap: True asynchronous execution with compute
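TMA itself is programmed at the CUDA/PTX level, but the overlap it buys can be sketched at the PyTorch level with streams and non-blocking copies (a conceptual illustration of copy/compute overlap only; this does not use TMA):

import torch

# Stage an independent host-to-device copy on a side stream while the default
# stream keeps computing; sync only when the copied data is actually needed.
copy_stream = torch.cuda.Stream()
x_cpu = torch.randn(8192, 8192, dtype=torch.float16).pin_memory()
x_gpu = torch.empty_like(x_cpu, device="cuda")
w = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)

with torch.cuda.stream(copy_stream):
    x_gpu.copy_(x_cpu, non_blocking=True)      # asynchronous copy

y = w @ w                                      # overlaps with the copy above
torch.cuda.current_stream().wait_stream(copy_stream)
z = x_gpu @ w                                  # safe: copy is complete here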
3. FP8 Precision Support
FlashAttention-3 is the first Flash Attention version with native FP8:
| Precision | Throughput (H100) | Memory | Use Case |
|---|---|---|---|
| FP16 | 740 TFLOPS | 2 bytes | Training default |
| BF16 | 740 TFLOPS | 2 bytes | Training (better range) |
| FP8 (E4M3) | 1,480 TFLOPS | 1 byte | Inference |
| FP8 (E5M2) | 1,480 TFLOPS | 1 byte | Training gradients |
FP8 doubles theoretical throughput but requires careful handling:
import torch
from flash_attn import flash_attn_func

# FP8 attention (FlashAttention-3 only)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
output = flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True)
4. Ping-Pong Scheduling
FlashAttention-3 implements producer-consumer pipelining between warp groups:
Warp Group 0: Compute tile [i] | Load tile [i+2]
Warp Group 1: Load tile [i+1] | Compute tile [i-1]
This hides memory latency by overlapping:
- WGMMA compute operations
- TMA memory loads
- Softmax and other non-matmul operations
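A toy Python model of this schedule (purely illustrative; real warp groups are hardware thread groups, not Python loops) shows how prefetching future tiles overlaps with computing current ones:

def pingpong_schedule(num_tiles, lookahead=2):
    """Toy model: two warp groups alternate computing the current tile
    while prefetching a tile `lookahead` steps ahead."""
    for step in range(num_tiles):
        group = step % 2                      # which warp group computes this step
        prefetch = step + lookahead
        print(f"step {step}: WG{group} computes tile {step}, "
              f"WG{1 - group} loads tile {prefetch if prefetch < num_tiles else '-'}")

pingpong_schedule(6)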
Performance Comparison
H100 SXM5 Benchmarks
Forward pass, FP16, batch=8, heads=32, d=128:
| Sequence | FA2 (TFLOPS) | FA3 (TFLOPS) | Speedup |
|---|---|---|---|
| 512 | 452 | 612 | 1.35x |
| 1024 | 498 | 685 | 1.38x |
| 2048 | 512 | 721 | 1.41x |
| 4096 | 508 | 738 | 1.45x |
| 8192 | 501 | 742 | 1.48x |
| 16384 | 495 | 739 | 1.49x |
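To reproduce numbers like these on your own hardware, a minimal timing harness could look like the sketch below (assumes flash-attn is installed; the FLOP count uses the standard 4·B·H·S²·D forward-pass estimate, halved for the causal mask):

import torch
from flash_attn import flash_attn_func

def bench_attention(batch=8, heads=32, seq=4096, dim=128, iters=20, causal=True):
    q, k, v = (torch.randn(batch, seq, heads, dim, device="cuda", dtype=torch.float16)
               for _ in range(3))
    for _ in range(5):                                  # warmup
        flash_attn_func(q, k, v, causal=causal)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        flash_attn_func(q, k, v, causal=causal)
    end.record()
    torch.cuda.synchronize()
    secs = start.elapsed_time(end) / 1000 / iters
    flops = 4 * batch * heads * seq * seq * dim * (0.5 if causal else 1.0)
    print(f"seq={seq}: {flops / secs / 1e12:.0f} TFLOPS")

bench_attention()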
A100 vs H100 with Each Version
| Setup | TFLOPS | % of Peak |
|---|---|---|
| A100 + FA2 | 225 | 72% |
| H100 + FA2 | 508 | 51% |
| H100 + FA3 | 740 | 75% |
FlashAttention-2 underutilizes H100 because it wasn't designed for Hopper's architecture. FlashAttention-3 achieves similar utilization on H100 as FA2 does on A100.
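For reference, the utilization column is simply achieved throughput divided by the dense FP16 tensor-core peak (roughly 312 TFLOPS for A100 SXM and roughly 989 TFLOPS for H100 SXM5):

peak = {"A100": 312, "H100": 989}           # dense FP16 tensor-core peak, TFLOPS
for setup, achieved, gpu in [("A100 + FA2", 225, "A100"),
                             ("H100 + FA2", 508, "H100"),
                             ("H100 + FA3", 740, "H100")]:
    print(f"{setup}: {achieved / peak[gpu]:.0%} of peak")
# -> 72%, 51%, 75%, matching the table above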
FP8 Performance
With FP8, FlashAttention-3 on H100 achieves even higher throughput:
| Precision | FA3 TFLOPS | vs FP16 FA3 |
|---|---|---|
| FP16 | 740 | 1.0x |
| BF16 | 740 | 1.0x |
| FP8 E4M3 | 1,200 | 1.62x |
Note: FP8 numbers are for inference. Training with FP8 requires master weights in higher precision.
Backward Compatibility
What Stays the Same
FlashAttention-3 maintains API compatibility with FlashAttention-2:
# Works with both FA2 and FA3
from flash_attn import flash_attn_func
output = flash_attn_func(
q, k, v,
causal=True,
softmax_scale=None,
dropout_p=0.0
)
What Changes
In practice, very little changes from the user's perspective:
- Import path: Same (flash_attn)
- Function signatures: Same
- Output format: Same [batch, seq, heads, dim]
- Numerical results: Match FA2 within floating-point tolerance
The upgrade is typically a drop-in replacement—just install the new version.
When to Upgrade to FlashAttention-3
Upgrade If:
- Using H100/H200 GPUs: FA3 is specifically optimized for Hopper
- Need FP8 attention: Only available in FA3
- Training at scale: 1.4-1.5x speedup compounds significantly
- Long sequences: Benefits increase with sequence length
Stay with FlashAttention-2 If:
- Using Ampere or older: FA3 requires Hopper (sm_90)
- CUDA < 12.3: FA3 needs newer CUDA toolkit
- Stability concerns: FA2 is more battle-tested
- Non-standard patterns: FA3 may have reduced feature set initially
Hardware Requirements
| Version | Min Compute Capability | Min CUDA | Supported GPUs |
|---|---|---|---|
| FA2 | 8.0 (Ampere) | 11.6 | A100, RTX 3090/4090, H100 |
| FA3 | 9.0 (Hopper) | 12.3 | H100, H200 only |
Installation
FlashAttention-2
# From prebuilt wheel (recommended)
pip install flash-attn
# Or install a specific prebuilt wheel (see the GitHub releases page)
pip install https://github.com/.../flash_attn-2.x.x+cuXXX-cp3XX-linux_x86_64.whl
FlashAttention-3
# FA3 requires Hopper GPU and CUDA 12.3+
pip install flash-attn --pre # Latest including FA3
# Or specify version
pip install "flash-attn>=3.0.0"
Checking Your Version
import flash_attn
print(flash_attn.__version__)
# Check Hopper support
import torch
print(f"CUDA compute capability: {torch.cuda.get_device_capability()}")
# Needs (9, 0) for FA3
Migration Considerations
Code Changes
Most code requires no changes:
# This works with both FA2 and FA3
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

# Standard usage
output = flash_attn_func(q, k, v, causal=True)
# Packed QKV
qkv = torch.stack([q, k, v], dim=2)  # [batch, seq, 3, heads, dim]
output = flash_attn_qkvpacked_func(qkv, causal=True)
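On mixed fleets it is common to fall back to PyTorch's built-in scaled_dot_product_attention when flash-attn is unavailable. A minimal sketch (the attention wrapper below is an illustration, not part of either library), assuming inputs in the flash-attn layout [batch, seq, heads, dim]:

import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func
    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False

def attention(q, k, v, causal=True):
    if HAS_FLASH and q.is_cuda:
        return flash_attn_func(q, k, v, causal=causal)
    # PyTorch SDPA expects [batch, heads, seq, dim], so transpose in and out
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)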
FP8 Migration
To leverage FP8 in FA3:
import torch
from flash_attn import flash_attn_func

# Step 1: Cast inputs to FP8 (in practice, apply per-tensor or per-channel scaling first)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
# Step 2: Run attention
output_fp8 = flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True)
# Step 3: Upcast output for downstream ops
output = output_fp8.to(torch.float16)
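Because E4M3 saturates around ±448, production FP8 pipelines usually scale each tensor before casting and carry the scale along for later de-scaling. A minimal per-tensor sketch (the scale_to_fp8 helper is hypothetical, not a flash-attn API; check your FA3 build for the exact descale arguments it expects):

import torch

def scale_to_fp8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Hypothetical helper: map the tensor's max magnitude near the FP8 max,
    # cast, and return the scale so the kernel (or caller) can de-scale later.
    fp8_max = torch.finfo(dtype).max                     # 448.0 for E4M3
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = fp8_max / amax
    return (x * scale).to(dtype), scale

q_fp8, q_scale = scale_to_fp8(q)
k_fp8, k_scale = scale_to_fp8(k)
v_fp8, v_scale = scale_to_fp8(v)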
Performance Tuning
FA3-specific optimizations:
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

# Larger head dimensions work better with FA3's larger tiles
# FA2 optimal: d=64
# FA3 optimal: d=128
# If using d=64, consider padding to d=128 for FA3
is_hopper = torch.cuda.get_device_capability()[0] >= 9
if q.shape[-1] == 64 and is_hopper:
    softmax_scale = q.shape[-1] ** -0.5  # keep the original 1/sqrt(64) scale after padding
    q = F.pad(q, (0, 64))
    k = F.pad(k, (0, 64))
    v = F.pad(v, (0, 64))
    output = flash_attn_func(q, k, v, softmax_scale=softmax_scale)[..., :64]
Real-World Impact
Training Time Reduction
For a 7B parameter model training on 8x H100:
| Configuration | Time/Step | Daily Tokens |
|---|---|---|
| FA2 on H100 | 1.8s | 3.8T |
| FA3 on H100 | 1.3s | 5.3T |
| FA3 + FP8 | 0.95s | 7.2T |
Cost Implications
At $3.50/hour for H100 spot instances:
| Setup | Tokens per $ | Cost for 1T tokens |
|---|---|---|
| FA2 | 1.09B | $917 |
| FA3 | 1.51B | $662 |
| FA3 + FP8 | 2.06B | $485 |
FA3 with FP8 reduces training costs by nearly 50% compared to FA2.
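As a quick sanity check, the cost column follows directly from the tokens-per-dollar figures:

tokens_per_dollar = {"FA2": 1.09e9, "FA3": 1.51e9, "FA3 + FP8": 2.06e9}
for setup, tpd in tokens_per_dollar.items():
    print(f"{setup}: ${1e12 / tpd:,.0f} per 1T tokens")
# FA2: $917, FA3: $662, FA3 + FP8: $485 -- matching the table above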
Future Outlook
FlashAttention continues active development:
- FA3.x: Ongoing optimizations for Hopper
- Blackwell support: Expected for next-gen GPUs (B100, B200)
- Broader FP8: Training support improvements
- New attention patterns: Ring attention, sparse patterns
References
- Shah, J., et al. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." arXiv:2407.08608
- Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691
- NVIDIA. (2025). "Hopper Architecture White Paper." NVIDIA Documentation
- NVIDIA. (2025). "CUDA C++ Programming Guide: Warp Group MMA." NVIDIA Documentation