GPU Optimization

Gradient Checkpointing Explained: Trade Compute for Memory

Deep dive into gradient checkpointing for training large models. Learn how it works, when to use it, and implementation details with PyTorch code examples.

Flash Attention Team · January 8, 2026 · 8 min read
Tags: gradient checkpointing, activation checkpointing, memory optimization, PyTorch, deep learning training

Gradient checkpointing is one of the most effective techniques for training large models on limited GPU memory. This guide explains exactly how it works and when you should use it.

The Problem: Activation Memory

During neural network training, the forward pass computes and stores intermediate values (activations) needed for the backward pass:

# Forward pass
h1 = layer1(x)      # Store h1 for backward
h2 = layer2(h1)     # Store h2 for backward
h3 = layer3(h2)     # Store h3 for backward
loss = criterion(h3, target)

# Backward pass (needs stored activations)
dh3 = grad(loss, h3)           # Uses h3
dh2 = grad(h3, h2) * dh3       # Uses h2
dh1 = grad(h2, h1) * dh2       # Uses h1

For a 7B-parameter transformer with 32 layers and batch size 8, a rough breakdown (exact figures depend on sequence length and precision):

Component               | Memory required
------------------------|----------------
Activations per layer   | ~500 MB
Total (32 layers)       | ~16 GB
With attention matrices | ~25 GB

This activation memory often exceeds model parameter memory, becoming the primary bottleneck.
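You can observe this directly by comparing allocated GPU memory before and after a forward pass; almost all of the growth is activations held for backward. A minimal sketch (requires a CUDA device; the layer sizes are purely illustrative):

import torch
import torch.nn as nn

model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).cuda()
x = torch.randn(8, 1024, 4096, device="cuda", requires_grad=True)

before = torch.cuda.memory_allocated()
out = model(x).sum()          # Forward pass keeps one activation per layer for backward
after = torch.cuda.memory_allocated()

print(f"Activation memory held for backward: {(after - before) / 1e9:.2f} GB")
out.backward()                # Stored activations are freed as backward consumes them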

How Gradient Checkpointing Works

Instead of storing all activations, gradient checkpointing:

  1. Stores only selected activations (checkpoints)
  2. Recomputes the rest during the backward pass

Without Checkpointing:
Forward:  [A1] → [A2] → [A3] → [A4] → Loss
          ↓      ↓      ↓      ↓
          Store  Store  Store  Store

With Checkpointing (every 2 layers):
Forward:  [A1] → [A2] → [A3] → [A4] → Loss
          ↓             ↓
          Store         Store (checkpoints)

Backward: Recompute A2 from A1, Recompute A4 from A3
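The same schedule can be written directly with torch.utils.checkpoint: wrap each two-layer segment in a checkpoint call so only the segment inputs are stored. A minimal sketch, with plain Linear layers standing in for A1-A4:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Four layers standing in for A1..A4
layers = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(4)])

def segment(first, second):
    # A two-layer segment; only its input tensor is kept for backward
    def run(x):
        return second(first(x))
    return run

x = torch.randn(8, 1024, requires_grad=True)

h = checkpoint(segment(layers[0], layers[1]), x, use_reentrant=False)    # stores x
out = checkpoint(segment(layers[2], layers[3]), h, use_reentrant=False)  # stores h
out.sum().backward()   # interior activations are recomputed from the stored segment inputs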

Mathematical Analysis

For a network with n layers:

Strategy                   | Activation memory | Forward compute
---------------------------|-------------------|----------------
No checkpointing           | O(n)              | 1x
Checkpoint every √n layers | O(√n)             | ~1.3x
Checkpoint every layer     | O(1)              | 2x

The optimal strategy checkpoints every √n layers, reducing memory from O(n) to O(√n) with only ~33% compute overhead.
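For intuition, here is the memory arithmetic for a 32-layer model, counting memory in units of "one layer's activations" (a rough back-of-the-envelope sketch; real layers differ in size):

import math

n = 32                                   # number of layers
segment = int(math.sqrt(n))              # 5 layers recomputed at a time
checkpoints = math.ceil(n / segment)     # 7 stored checkpoints

no_checkpointing = n                     # all 32 activations kept
with_checkpointing = checkpoints + segment   # checkpoints + one live segment = 12 units

print(f"No checkpointing: {no_checkpointing} units of per-layer activations")
print(f"sqrt(n) schedule: {with_checkpointing} units (~{n / with_checkpointing:.1f}x less)")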

PyTorch Implementation

Basic Usage

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Without checkpointing: stores all intermediate activations
        h = self.norm1(x)
        x = x + self.attention(h, h, h)[0]
        x = x + self.ffn(self.norm2(x))
        return x

class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.block = TransformerBlock(dim, num_heads)

    def forward(self, x):
        # With checkpointing: only stores input, recomputes during backward
        return checkpoint(self.block, x, use_reentrant=False)
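Usage is identical to the plain block; only what gets stored for backward changes. A quick sketch with illustrative sizes:

import torch

block = CheckpointedTransformerBlock(dim=512, num_heads=8)
x = torch.randn(128, 4, 512, requires_grad=True)  # (seq, batch, dim): nn.MultiheadAttention default layout

out = block(x)            # Only x and the block parameters are kept for backward
out.sum().backward()      # The block's forward is recomputed here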

Checkpointing Sequential Layers

For models with many sequential layers:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            *[TransformerBlock(512, 8) for _ in range(24)]
        )

    def forward(self, x):
        # Checkpoint every 4 layers (creates 6 segments)
        segments = 6
        return checkpoint_sequential(self.layers, segments, x, use_reentrant=False)

The use_reentrant Parameter

Always pass use_reentrant=False explicitly. It is the recommended implementation; current PyTorch releases still fall back to the older reentrant version (with a deprecation warning) if you omit the argument:

# OLD (use_reentrant=True) - known issues:
# - Needs at least one input tensor with requires_grad=True
# - Backward hooks may not fire
# - Memory leaks in some edge cases

# NEW (use_reentrant=False) - recommended:
# - Works with any input
# - Compatible with hooks
# - More predictable behavior

checkpoint(fn, x, use_reentrant=False)

Hugging Face Integration

Enabling in Transformers

import torch
from transformers import AutoModelForCausalLM

# Method 1: Enable on the model after loading
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()

# Method 2: In TrainingArguments
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
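The TrainingArguments route only takes effect once the arguments are passed to a Trainer. A minimal sketch; train_dataset is a placeholder for your own tokenized dataset:

from transformers import Trainer

trainer = Trainer(
    model=model,                   # model from Method 1 above
    args=training_args,            # gradient_checkpointing=True is applied here
    train_dataset=train_dataset,   # placeholder: your tokenized dataset
)
trainer.train()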

Model-Specific Behavior

Different models implement checkpointing at different granularities:

Model | Checkpoint granularity         | Memory reduction
------|--------------------------------|-----------------
LLaMA | Per transformer block          | ~60%
GPT-2 | Per transformer block          | ~60%
T5    | Encoder and decoder separately | ~50%
BERT  | Per transformer layer          | ~60%

Selective Checkpointing

Not all layers benefit equally from checkpointing; attention blocks hold the largest activations, so you can checkpoint a subset of layers rather than all of them:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SelectiveCheckpointModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(config.dim, config.num_heads)
            for _ in range(config.num_layers)
        ])
        # Checkpoint every other layer
        self.checkpoint_layers = set(range(0, config.num_layers, 2))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if i in self.checkpoint_layers:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

Memory-Optimal Strategy

Based on the paper "Training Deep Nets with Sublinear Memory Cost" (Chen et al., 2016):

import math

def optimal_checkpoint_schedule(num_layers):
    """
    Returns which layers to checkpoint for O(√n) memory.
    """
    segment_size = int(math.sqrt(num_layers))
    checkpoints = list(range(0, num_layers, segment_size))
    return checkpoints

# For 32 layers: checkpoint at [0, 5, 10, 15, 20, 25, 30]
# Memory: O(√32) ≈ O(6) instead of O(32)
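For a stack of layers held in an nn.Sequential, the same O(√n) behaviour comes from checkpoint_sequential with roughly √n segments, which stores only each segment's input and recomputes one segment at a time. A runnable sketch with stand-in Linear blocks:

import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

num_layers = 32
layers = nn.Sequential(*[nn.Linear(256, 256) for _ in range(num_layers)])  # stand-in blocks

segments = round(math.sqrt(num_layers))   # 6 segments of ~5-6 layers each

x = torch.randn(4, 256, requires_grad=True)
# Only ~sqrt(n) segment inputs are stored; each segment is recomputed during backward
out = checkpoint_sequential(layers, segments, x, use_reentrant=False)
out.sum().backward()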

Performance Impact

Benchmarks

Measured on A100 80GB with LLaMA-7B:

Configuration                  | Memory | Training speed | vs. baseline
-------------------------------|--------|----------------|-------------
No checkpointing               | 45 GB  | 1.00x          | baseline
Checkpoint all layers          | 18 GB  | 0.72x          | -28%
Selective checkpointing        | 24 GB  | 0.85x          | -15%
Checkpointing + Flash Attention| 12 GB  | 0.95x          | -5%

When Checkpointing Helps vs Hurts

Use checkpointing when:

  • GPU memory is the bottleneck (see the quick check after these lists)
  • You want to train larger batch sizes
  • Training longer sequences
  • Model doesn't fit otherwise

Avoid checkpointing when:

  • You have spare GPU memory
  • Training time is more important than memory
  • Already using other memory optimizations effectively
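A quick way to check whether memory is actually the bottleneck before reaching for checkpointing; this uses torch.cuda.mem_get_info, which reports free and total bytes on the current device:

import torch

free, total = torch.cuda.mem_get_info()   # bytes free / total on the current device
print(f"Free GPU memory: {free / 1e9:.1f} GB of {total / 1e9:.1f} GB")

# Rule of thumb: if the next batch-size or sequence-length increase will not fit in
# this headroom, enable checkpointing; otherwise keep the faster non-checkpointed path.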

Combining with Other Techniques

Checkpointing + Flash Attention

Flash Attention already reduces attention memory. Combining with checkpointing:

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # Reduces attention memory
    torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()  # Reduces activation memory

# Flash Attention: O(N) instead of O(N²) for attention
# Checkpointing: Recomputes all other activations
# Combined: Maximum memory efficiency

Checkpointing + Mixed Precision

import torch.nn as nn
from torch.cuda.amp import autocast
from torch.utils.checkpoint import checkpoint

class CheckpointedWithAMP(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        def custom_forward(x):
            # Autocast is re-entered when the segment is recomputed in backward
            with autocast():
                return self.layers(x)

        return checkpoint(custom_forward, x, use_reentrant=False)
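In a training loop this composes with the usual AMP machinery. A minimal sketch using GradScaler for FP16; it requires a CUDA device, and the layer stack, sizes, and loss are purely illustrative:

import torch
import torch.nn as nn

layers = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
model = CheckpointedWithAMP(layers)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 1024, device="cuda")
target = torch.randn(8, 1024, device="cuda")

for step in range(10):
    optimizer.zero_grad(set_to_none=True)
    out = model(x)                                # autocast + checkpointing happen inside forward
    loss = (out.float() - target).pow(2).mean()   # simple MSE computed in FP32
    scaler.scale(loss).backward()                 # checkpointed forward replayed here under autocast
    scaler.step(optimizer)
    scaler.update()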

Checkpointing + Gradient Accumulation

training_args = TrainingArguments(
    per_device_train_batch_size=2,      # Small batch for memory
    gradient_accumulation_steps=16,     # Effective batch: 32
    gradient_checkpointing=True,        # Further memory savings
)
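The equivalent pattern in a hand-written loop, for readers not using Trainer. This is only a sketch: model, optimizer, and dataloader are assumed to exist and are placeholders here.

accumulation_steps = 16

optimizer.zero_grad(set_to_none=True)
for step, batch in enumerate(dataloader):
    outputs = model(**batch)                     # checkpointed forward pass
    loss = outputs.loss / accumulation_steps     # scale so accumulated gradients average correctly
    loss.backward()                              # activations recomputed segment by segment

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)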

Common Issues and Solutions

Issue 1: Slower Training

Checkpointing recomputes forward pass during backward, roughly doubling compute for checkpointed sections.

Solution: Checkpoint selectively, not every layer:

# Instead of checkpointing every layer:
for layer in self.layers:
    x = checkpoint(layer, x, use_reentrant=False)

# Checkpoint every Nth layer:
for i, layer in enumerate(self.layers):
    if i % 4 == 0:  # Every 4th layer
        x = checkpoint(layer, x, use_reentrant=False)
    else:
        x = layer(x)

Issue 2: Non-Deterministic Gradients

Stochastic operations such as dropout can produce different results when the forward pass is recomputed.

Solution: Keep preserve_rng_state=True (the default), which restores the RNG state during recomputation:

from torch.utils.checkpoint import checkpoint

output = checkpoint(
    fn,
    input,
    use_reentrant=False,
    preserve_rng_state=True,  # Preserve dropout masks
)

Issue 3: Hooks Not Working

With use_reentrant=True, hooks on checkpointed modules may not fire.

Solution: Use use_reentrant=False:

# This ensures hooks work correctly
output = checkpoint(fn, x, use_reentrant=False)

Issue 4: CUDA Graphs Incompatibility

Checkpointing is incompatible with CUDA graphs due to dynamic recomputation.

Solution: Disable checkpointing when using CUDA graphs, or skip CUDA graphs:

# For inference with CUDA graphs: disable checkpointing
model.gradient_checkpointing_disable()

Memory Estimation

Estimate memory savings before implementing:

def estimate_checkpoint_savings(model, batch_size, seq_len, num_checkpoints):
    """
    Rough estimate of activation-memory savings from checkpointing.
    """
    # Activation memory without checkpointing
    num_layers = model.config.num_hidden_layers   # transformer blocks, not every submodule
    bytes_per_activation = 2                      # FP16 / BF16
    hidden_dim = model.config.hidden_size

    # Rough estimate of activations per layer (several intermediates per block)
    activations_per_layer = batch_size * seq_len * hidden_dim * 4
    total_without = num_layers * activations_per_layer * bytes_per_activation

    # With checkpointing: only one segment's activations are live at a time
    # (the stored checkpoints themselves are ignored here for simplicity)
    segment_size = num_layers // num_checkpoints
    total_with = segment_size * activations_per_layer * bytes_per_activation

    savings = (total_without - total_with) / 1e9
    print(f"Estimated savings: {savings:.2f} GB")
    return savings
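For example, for a loaded Hugging Face model (config.hidden_size and config.num_hidden_layers come from its config) with batch size 8, sequence length 2048, and 6 checkpointed segments:

# `model` is the Hugging Face model loaded earlier in this article
estimate_checkpoint_savings(model, batch_size=8, seq_len=2048, num_checkpoints=6)
# Prints a rough figure; treat it as an order-of-magnitude estimate only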

References

  1. Chen, T., et al. (2016). "Training Deep Nets with Sublinear Memory Cost." arXiv:1604.06174

  2. Griewank, A., & Walther, A. (2000). "Algorithm 799: Revolve: An Implementation of Checkpointing for the Reverse or Adjoint Mode of Computational Differentiation." ACM Transactions on Mathematical Software.

  3. PyTorch Documentation. "torch.utils.checkpoint."

  4. Korthikanti, V., et al. (2022). "Reducing Activation Recomputation in Large Transformer Models." arXiv:2205.05198
