
Flash Attention 2 vs Flash Attention 3: What's New and When to Upgrade

Detailed comparison of FlashAttention-2 and FlashAttention-3. Covers Hopper optimizations, FP8 support, performance gains, and migration considerations for H100 users.

Flash Attention Team · January 8, 2026 · 8 min read
Tags: flash attention 2, flash attention 3, H100, Hopper, FP8, CUDA optimization

FlashAttention-3, released in 2024, brings significant architectural optimizations for NVIDIA's Hopper GPUs (H100). This guide compares the two versions and helps you decide when upgrading makes sense.

Version Overview

Feature | FlashAttention-2 | FlashAttention-3
Release | July 2023 | July 2024
Target Architecture | Ampere (A100) | Hopper (H100)
Peak Performance | 230 TFLOPS (A100) | 740 TFLOPS (H100)
FP8 Support | No | Yes
Asynchronous Ops | Limited | Full TMA support
Min CUDA | 11.6 | 12.3

What's New in FlashAttention-3

1. Warp Group Matrix Multiply (WGMMA)

FlashAttention-3 leverages Hopper's WGMMA instructions, which operate at warp-group scope (4 warps = 128 threads) instead of warp scope (32 threads):

FlashAttention-2: warp-level mma.sync instructions
FlashAttention-3: warp-group-level wgmma.mma_async instructions

WGMMA enables:

  • Larger tile sizes: 128×128 or 256×64 vs 64×64
  • Better register utilization: Distributes data across more threads
  • Higher arithmetic intensity: More compute per memory access
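
As a rough, back-of-the-envelope sketch of the arithmetic-intensity point (not FA3's actual tiling logic), the helper below counts FLOPs per byte for one attention tile step, assuming the query tile stays on-chip while key/value tiles stream in from memory:

def tile_arithmetic_intensity(br, bc, d, bytes_per_elem=2):
    """Rough FLOPs per byte for one attention tile step (FP16 inputs)."""
    # Two matmuls per step: S = Q K^T and O += P V, each ~2*br*bc*d FLOPs
    flops = 4 * br * bc * d
    # Each step streams one K tile and one V tile; the Q tile stays resident
    bytes_moved = 2 * bc * d * bytes_per_elem
    return flops / bytes_moved

print(tile_arithmetic_intensity(64, 64, 128))    # ~64 FLOPs/byte
print(tile_arithmetic_intensity(128, 128, 128))  # ~128 FLOPs/byte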

2. Tensor Memory Accelerator (TMA)

H100 introduces TMA for hardware-accelerated asynchronous data movement:

// FlashAttention-2: Manual async copy
cp.async.ca.shared.global [smem_ptr], [gmem_ptr], 16;
cp.async.commit_group;
cp.async.wait_group 0;

// FlashAttention-3: TMA descriptor-based copy
cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes
    [smem_ptr], [tensor_map, coords], [mbar];

TMA benefits:

  • Address generation offload: GPU computes addresses in hardware
  • Multicast support: Single memory read broadcasts to multiple SMs
  • Better overlap: True asynchronous execution with compute

3. FP8 Precision Support

FlashAttention-3 is the first Flash Attention version with native FP8:

Precision | Throughput (H100) | Memory | Use Case
FP16 | 740 TFLOPS | 2 bytes | Training default
BF16 | 740 TFLOPS | 2 bytes | Training (better range)
FP8 (E4M3) | 1,480 TFLOPS | 1 byte | Inference
FP8 (E5M2) | 1,480 TFLOPS | 1 byte | Training gradients

FP8 doubles theoretical throughput but requires careful handling:

import torch
from flash_attn import flash_attn_func

# FP8 attention (FlashAttention-3 only)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)

output = flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True)

4. Ping-Pong Scheduling

FlashAttention-3 implements producer-consumer pipelining between warp groups:

Warp Group 0: Compute tile [i]   | Load tile [i+2]
Warp Group 1: Load tile [i+1]    | Compute tile [i-1]

This hides memory latency by overlapping:

  • WGMMA compute operations
  • TMA memory loads
  • Softmax and other non-matmul operations
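
The schedule can be pictured with a toy Python loop (purely conceptual, not GPU code): two warp groups alternate producer and consumer roles each step:

def pingpong_schedule(num_tiles):
    # One warp group computes the current tile while the other
    # prefetches the next one; the roles swap every step
    for step in range(num_tiles):
        compute_wg = step % 2
        load_wg = 1 - compute_wg
        print(f"step {step}: WG{compute_wg} computes tile {step}, "
              f"WG{load_wg} loads tile {step + 1}")

pingpong_schedule(4)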

Performance Comparison

H100 SXM5 Benchmarks

Forward pass, FP16, batch=8, heads=32, d=128:

Sequence | FA2 (TFLOPS) | FA3 (TFLOPS) | Speedup
512 | 452 | 612 | 1.35x
1024 | 498 | 685 | 1.38x
2048 | 512 | 721 | 1.41x
4096 | 508 | 738 | 1.45x
8192 | 501 | 742 | 1.48x
16384 | 495 | 739 | 1.49x
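
To reproduce numbers of this kind on your own hardware, a minimal timing harness along these lines works with both FA2 and FA3; it uses the conventional 4·b·h·s²·d FLOP count (halved when causal masking is enabled), and exact figures will vary with clocks, drivers, and library versions:

import time
import torch
from flash_attn import flash_attn_func

def measure_tflops(seqlen, batch=8, heads=32, d=128, causal=False, iters=20):
    q, k, v = (torch.randn(batch, seqlen, heads, d, device="cuda", dtype=torch.float16)
               for _ in range(3))
    for _ in range(5):                      # warm-up
        flash_attn_func(q, k, v, causal=causal)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        flash_attn_func(q, k, v, causal=causal)
    torch.cuda.synchronize()
    seconds = (time.perf_counter() - start) / iters
    flops = 4 * batch * heads * seqlen**2 * d   # QK^T plus PV
    if causal:
        flops //= 2                             # causal masking skips ~half the tiles
    return flops / seconds / 1e12

print(f"{measure_tflops(4096):.0f} TFLOPS")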

A100 vs H100 with Each Version

Setup | TFLOPS | % of Peak
A100 + FA2 | 225 | 72%
H100 + FA2 | 508 | 51%
H100 + FA3 | 740 | 75%

FlashAttention-2 underutilizes H100 because it wasn't designed for Hopper's architecture. FlashAttention-3 achieves similar utilization on H100 as FA2 does on A100.
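
For reference, the utilization column follows from the published dense FP16/BF16 Tensor Core peaks, roughly 312 TFLOPS for A100 and 989 TFLOPS for H100 SXM (without sparsity):

PEAK_TFLOPS = {"A100": 312, "H100": 989}   # dense FP16/BF16 Tensor Core peaks, no sparsity

for setup, achieved, gpu in [("A100 + FA2", 225, "A100"),
                             ("H100 + FA2", 508, "H100"),
                             ("H100 + FA3", 740, "H100")]:
    print(f"{setup}: {achieved / PEAK_TFLOPS[gpu]:.0%} of peak")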

FP8 Performance

With FP8, FlashAttention-3 on H100 achieves even higher throughput:

Precision | FA3 TFLOPS | vs FP16 FA3
FP16 | 740 | 1.0x
BF16 | 740 | 1.0x
FP8 E4M3 | 1,200 | 1.62x

Note: FP8 numbers are for inference. Training with FP8 requires master weights in higher precision.

Backward Compatibility

What Stays the Same

FlashAttention-3 maintains API compatibility with FlashAttention-2:

# Works with both FA2 and FA3
from flash_attn import flash_attn_func

output = flash_attn_func(
    q, k, v,
    causal=True,
    softmax_scale=None,
    dropout_p=0.0
)

What Changes

At the API level, very little:

  1. Import path: unchanged (flash_attn)
  2. Function signatures: unchanged
  3. Output format: unchanged, [batch, seq, heads, dim]
  4. Numerical results: match FA2 within floating-point tolerance

The upgrade is typically a drop-in replacement—just install the new version.
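
A quick post-upgrade sanity check is to compare against PyTorch's scaled_dot_product_attention reference; note the layout difference (flash-attn uses [batch, seq, heads, dim], SDPA expects [batch, heads, seq, dim]). Differences should stay at FP16 rounding level, roughly 1e-3:

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

b, s, h, d = 2, 512, 8, 64
q, k, v = (torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
           for _ in range(3))

out_fa = flash_attn_func(q, k, v, causal=True)

# SDPA expects [batch, heads, seq, dim], so transpose in and out
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

print((out_fa - out_ref).abs().max())  # expect roughly 1e-3 in FP16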

When to Upgrade to FlashAttention-3

Upgrade If:

  • Using H100/H200 GPUs: FA3 is specifically optimized for Hopper
  • Need FP8 attention: Only available in FA3
  • Training at scale: 1.4-1.5x speedup compounds significantly
  • Long sequences: Benefits increase with sequence length

Stay with FlashAttention-2 If:

  • Using Ampere or older: FA3 requires Hopper (sm_90)
  • CUDA < 12.3: FA3 needs newer CUDA toolkit
  • Stability concerns: FA2 is more battle-tested
  • Non-standard patterns: FA3 may have reduced feature set initially

Hardware Requirements

Version | Min Compute Capability | Min CUDA | Supported GPUs
FA2 | 8.0 (Ampere) | 11.6 | A100, RTX 30/40 series, H100
FA3 | 9.0 (Hopper) | 12.3 | H100, H200 only

Installation

FlashAttention-2

# From prebuilt wheel (recommended)
pip install flash-attn

# Or find exact wheel at flashattn.dev
pip install https://github.com/.../flash_attn-2.x.x+cuXXX-cp3XX-linux_x86_64.whl

FlashAttention-3

# FA3 requires Hopper GPU and CUDA 12.3+
pip install flash-attn --pre  # Latest including FA3

# Or specify version
pip install "flash-attn>=3.0.0"  # quote the specifier so the shell doesn't treat > as a redirect

Checking Your Version

import flash_attn
print(flash_attn.__version__)

# Check Hopper support
import torch
print(f"CUDA compute capability: {torch.cuda.get_device_capability()}")
# Needs (9, 0) for FA3

Migration Considerations

Code Changes

Most code requires no changes:

# This works with both FA2 and FA3
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

# Standard usage
output = flash_attn_func(q, k, v, causal=True)

# Packed QKV
qkv = torch.stack([q, k, v], dim=2)  # [batch, seq, 3, heads, dim]
output = flash_attn_qkvpacked_func(qkv, causal=True)

FP8 Migration

To leverage FP8 in FA3:

# Step 1: Cast inputs to FP8 (see the per-tensor scaling sketch below)
import torch
from flash_attn import flash_attn_func

q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)

# Step 2: Run attention
output_fp8 = flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True)

# Step 3: Upcast output for downstream ops
output = output_fp8.to(torch.float16)
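
The plain casts above clip anything outside FP8's narrow dynamic range. A common mitigation is per-tensor scaling before the cast; the scale_to_fp8 helper below is a hypothetical sketch, and how the resulting scales are re-applied (for example, folded into softmax_scale or the output) depends on the FP8 entry point your flash-attn build exposes, so check its documentation:

import torch

def scale_to_fp8(x):
    # Hypothetical helper: map the largest magnitude near E4M3's max (~448)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = fp8_max / amax
    return (x.float() * scale).to(torch.float8_e4m3fn), scale

q_fp8, q_scale = scale_to_fp8(q)
k_fp8, k_scale = scale_to_fp8(k)
v_fp8, v_scale = scale_to_fp8(v)
# Keep the scales: their inverses are needed to undo the scaling downstream.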

Performance Tuning

FA3-specific optimizations:

# Larger head dimensions work better with FA3's larger tiles
# FA2 optimal: d=64
# FA3 optimal: d=128
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

is_hopper = torch.cuda.get_device_capability()[0] == 9

# If using d=64, consider padding to d=128 for FA3.
# Pass the original scale explicitly: the default softmax_scale would
# become 1/sqrt(128) after padding and change the results.
if q.shape[-1] == 64 and is_hopper:
    softmax_scale = q.shape[-1] ** -0.5
    q = F.pad(q, (0, 64))
    k = F.pad(k, (0, 64))
    v = F.pad(v, (0, 64))
    output = flash_attn_func(q, k, v, softmax_scale=softmax_scale)[..., :64]

Real-World Impact

Training Time Reduction

For a 7B parameter model training on 8x H100:

Configuration | Time/Step | Daily Tokens
FA2 on H100 | 1.8s | 3.8T
FA3 on H100 | 1.3s | 5.3T
FA3 + FP8 | 0.95s | 7.2T

Cost Implications

At $3.50/hour for H100 spot instances:

Setup | Tokens per $1,000 | Cost for 1T tokens
FA2 | 1.09T | $917
FA3 | 1.51T | $662
FA3 + FP8 | 2.06T | $485

FA3 with FP8 reduces training costs by nearly 50% compared to FA2.
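
The cost column is simple arithmetic on the tokens-per-$1,000 figures; a tiny helper reproduces it:

def cost_for_1t_tokens(trillion_tokens_per_1k_dollars):
    # $1,000 buys this many trillion tokens, so 1T tokens costs 1000 / rate dollars
    return 1000 / trillion_tokens_per_1k_dollars

for setup, rate in [("FA2", 1.09), ("FA3", 1.51), ("FA3 + FP8", 2.06)]:
    print(f"{setup}: ${cost_for_1t_tokens(rate):,.0f} for 1T tokens")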

Future Outlook

FlashAttention continues active development:

  • FA3.x: Ongoing optimizations for Hopper
  • Blackwell support: Expected for next-gen GPUs (B100, B200)
  • Broader FP8: Training support improvements
  • New attention patterns: Ring attention, sparse patterns

References

  1. Shah, J., et al. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." arXiv:2407.08608

  2. Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691

  3. NVIDIA. (2025). "Hopper Architecture White Paper." NVIDIA Documentation

  4. NVIDIA. (2025). "CUDA C++ Programming Guide: Warp Group MMA." NVIDIA Documentation
