Gradient checkpointing is one of the most effective techniques for training large models on limited GPU memory. This guide explains exactly how it works and when you should use it.
The Problem: Activation Memory
During neural network training, the forward pass computes and stores intermediate values (activations) needed for the backward pass:
# Forward pass
h1 = layer1(x) # Store h1 for backward
h2 = layer2(h1) # Store h2 for backward
h3 = layer3(h2) # Store h3 for backward
loss = criterion(h3, target)
# Backward pass (needs stored activations)
dh3 = grad(loss, h3) # Uses h3
dh2 = grad(h3, h2) * dh3 # Uses h2
dh1 = grad(h2, h1) * dh2 # Uses h1
For a 7B-parameter transformer with 32 layers and batch size 8 (the exact figures also depend on sequence length and precision):
| Component | Memory Required |
|---|---|
| Activations per layer | ~500 MB |
| Total (32 layers) | ~16 GB |
| With attention matrices | ~25 GB |
This activation memory often exceeds model parameter memory, becoming the primary bottleneck.
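To see where figures like these come from, here is a back-of-the-envelope estimate. The sequence length (1,024), hidden size (4,096), and the count of stored tensors per block are illustrative assumptions chosen to roughly reproduce the table above, not measurements:

# Rough activation-memory estimate for one transformer block (fp16/bf16 = 2 bytes).
# All values below are illustrative assumptions, not measured numbers.
batch, seq_len, hidden, layers = 8, 1024, 4096, 32
bytes_per_value = 2

hidden_tensor = batch * seq_len * hidden * bytes_per_value  # one (B, S, H) tensor ≈ 67 MB
ffn_tensor = 4 * hidden_tensor                              # the 4x FFN expansion ≈ 268 MB
per_layer = 4 * hidden_tensor + ffn_tensor                  # a handful of stored intermediates

print(f"per layer:  {per_layer / 1e6:.0f} MB")              # ≈ 537 MB
print(f"all layers: {layers * per_layer / 1e9:.1f} GB")     # ≈ 17 GB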
How Gradient Checkpointing Works
Instead of storing all activations, gradient checkpointing:
- Stores only selected activations (checkpoints)
- Recomputes the rest during backward pass
Without Checkpointing:

Forward:  [A1] → [A2] → [A3] → [A4] → Loss
          Store  Store  Store  Store

With Checkpointing (every 2 layers):

Forward:  [A1] → [A2] → [A3] → [A4] → Loss
          Store         Store          (checkpoints: A1 and A3)
Backward: Recompute A2 from A1, Recompute A4 from A3
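The same bookkeeping can be written out by hand with plain autograd. This is only a sketch of the idea (toy linear layers, two segments of two layers each), not how torch.utils.checkpoint is implemented: run the forward under no_grad keeping only the segment inputs, then re-run each segment with gradients enabled while sweeping backward.

import torch

layers = [torch.nn.Linear(16, 16) for _ in range(4)]
x = torch.randn(2, 16)

# Forward pass: keep only the two segment inputs (the checkpoints A1 and A3).
with torch.no_grad():
    a1 = x                              # checkpoint for segment 1 (layers 0-1)
    a3 = layers[1](layers[0](a1))       # checkpoint for segment 2 (layers 2-3)
    out = layers[3](layers[2](a3))

grad_out = torch.ones_like(out)         # stand-in for dLoss/d(out)

# Backward, last segment first: recompute its activations, then backprop through them.
a3.requires_grad_(True)
out = layers[3](layers[2](a3))          # recompute A4 from A3
torch.autograd.backward(out, grad_out)  # fills grads of layers 2-3 and a3.grad

# Earlier segment: recompute A2 from A1 and continue the chain with a3.grad.
h = layers[1](layers[0](a1))
torch.autograd.backward(h, a3.grad)     # fills grads of layers 0-1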
Mathematical Analysis
For a network with n layers:
| Strategy | Activation Memory | Extra Forward Passes |
|---|---|---|
| No checkpointing | O(n) | 0 |
| Checkpoint every √n layers | O(√n) | ~1 (each activation recomputed at most once) |
| No stored checkpoints (recompute from the input) | O(1) | up to O(n) |
Checkpointing every √n layers is the standard sweet spot: activation memory drops from O(n) to O(√n) at the cost of roughly one extra forward pass. Since the backward pass costs roughly twice as much as the forward, that works out to about 30% extra compute per training step.
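Why √n is the right spacing, under the simplifying assumption that a stored checkpoint and a recomputed activation cost the same memory per layer: with k evenly spaced checkpoints, peak activation memory during the backward pass is roughly

M(k) ≈ k + n/k

(the k stored checkpoints plus the one segment of n/k layers being recomputed at that moment). Setting dM/dk = 1 − n/k² to zero gives k = √n, so the minimum is M(√n) ≈ 2√n, i.e. O(√n).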
PyTorch Implementation
Basic Usage
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads):
super().__init__()
self.attention = nn.MultiheadAttention(dim, num_heads)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim),
)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
    def forward(self, x):
        # Without checkpointing: every intermediate activation is stored for backward
        h = self.norm1(x)
        x = x + self.attention(h, h, h)[0]
        x = x + self.ffn(self.norm2(x))
        return x
class CheckpointedTransformerBlock(nn.Module):
def __init__(self, dim, num_heads):
super().__init__()
self.block = TransformerBlock(dim, num_heads)
def forward(self, x):
# With checkpointing: only stores input, recomputes during backward
return checkpoint(self.block, x, use_reentrant=False)
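A quick smoke test of the wrapper above; the sizes are arbitrary, and the (seq_len, batch, embed_dim) layout follows nn.MultiheadAttention's default (batch_first=False):

import torch

block = CheckpointedTransformerBlock(dim=512, num_heads=8)
x = torch.randn(16, 4, 512)          # (seq_len, batch, dim)

out = block(x)                        # forward keeps only the block input
out.sum().backward()                  # the block runs a second time here to rebuild activations
print(block.block.ffn[0].weight.grad is not None)   # True: gradients flow as usual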
Checkpointing Sequential Layers
For models with many sequential layers:
from torch.utils.checkpoint import checkpoint_sequential
class DeepModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
*[TransformerBlock(512, 8) for _ in range(24)]
)
def forward(self, x):
        # Split the 24 layers into 6 checkpointed segments (4 layers each);
        # only the segment-boundary activations are kept.
segments = 6
return checkpoint_sequential(self.layers, segments, x, use_reentrant=False)
The use_reentrant Parameter
Always pass use_reentrant=False explicitly. It is the recommended variant and the planned future default; older PyTorch releases still default to use_reentrant=True and warn when the argument is omitted:
# OLD (use_reentrant=True) - has issues with:
# - Requires inputs to require gradients
# - Issues with hooks
# - Memory leaks in some cases
# NEW (use_reentrant=False) - recommended:
# - Works with any input
# - Compatible with hooks
# - More predictable behavior
checkpoint(fn, x, use_reentrant=False)
Hugging Face Integration
Enabling in Transformers
import torch
from transformers import AutoModelForCausalLM
# Method 1: During loading
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()  # Transformers also disables the KV cache (use_cache) during checkpointed training
# Method 2: In TrainingArguments
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="./output",
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
)
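For completeness, here is a hedged sketch of how those arguments reach a training loop; train_dataset is a placeholder for a tokenized dataset you have already prepared, not something defined in this guide:

from transformers import Trainer

trainer = Trainer(
    model=model,                  # the model loaded above
    args=training_args,           # gradient_checkpointing=True takes effect here
    train_dataset=train_dataset,  # assumed to exist
)
trainer.train()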
Model-Specific Behavior
Different models implement checkpointing at different granularities:
| Model | Checkpoint Granularity | Memory Reduction |
|---|---|---|
| LLaMA | Per transformer block | ~60% |
| GPT-2 | Per transformer block | ~60% |
| T5 | Encoder and decoder separately | ~50% |
| BERT | Per transformer layer | ~60% |
Selective Checkpointing
Not all layers benefit equally from checkpointing. Attention blocks (with their O(N²) score matrices when Flash Attention is not used) hold the largest activations, and checkpointing only a subset of layers trades a smaller memory saving for less recomputation:
class SelectiveCheckpointModel(nn.Module):
def __init__(self, config):
super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(config.dim, config.num_heads)  # same block as defined earlier
            for _ in range(config.num_layers)
        ])
# Checkpoint every other layer
self.checkpoint_layers = set(range(0, config.num_layers, 2))
def forward(self, x):
for i, layer in enumerate(self.layers):
if i in self.checkpoint_layers:
x = checkpoint(layer, x, use_reentrant=False)
else:
x = layer(x)
return x
Memory-Optimal Strategy
Based on the paper "Training Deep Nets with Sublinear Memory Cost" (Chen et al., 2016):
import math
def optimal_checkpoint_schedule(num_layers):
"""
Returns which layers to checkpoint for O(√n) memory.
"""
segment_size = int(math.sqrt(num_layers))
checkpoints = list(range(0, num_layers, segment_size))
return checkpoints
# For 32 layers: checkpoint at [0, 5, 10, 15, 20, 25, 30]
# Peak activation memory scales with √32 ≈ 6 layers' worth instead of all 32
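The schedule can be plugged into the selective model above (checkpoint_layers is the attribute that class already defines); a few example depths:

for n in (12, 24, 32, 48):
    print(n, optimal_checkpoint_schedule(n))
# 12 [0, 3, 6, 9]
# 24 [0, 4, 8, 12, 16, 20]
# 32 [0, 5, 10, 15, 20, 25, 30]
# 48 [0, 6, 12, 18, 24, 30, 36, 42]

# e.g. model.checkpoint_layers = set(optimal_checkpoint_schedule(32))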
Performance Impact
Benchmarks
Measured on A100 80GB with LLaMA-7B:
| Configuration | Memory | Relative Speed | Slowdown |
|---|---|---|---|
| No checkpointing | 45 GB | 1.0x | baseline |
| Checkpoint all | 18 GB | 0.72x | -28% |
| Checkpoint selective | 24 GB | 0.85x | -15% |
| Checkpoint + Flash Attn | 12 GB | 0.95x | -5% |
When Checkpointing Helps vs Hurts
Use checkpointing when:
- GPU memory is the bottleneck
- You want to train with larger batch sizes
- You are training on longer sequences
- The model doesn't fit in memory otherwise
Avoid checkpointing when:
- You have spare GPU memory
- Training time matters more than memory
- Other memory optimizations already keep you within budget
Combining with Other Techniques
Checkpointing + Flash Attention
Flash Attention already reduces attention memory. Combining with checkpointing:
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2", # Reduces attention memory
torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable() # Reduces activation memory
# Flash Attention: O(N) instead of O(N²) for attention
# Checkpointing: Recomputes all other activations
# Combined: Maximum memory efficiency
Checkpointing + Mixed Precision
from torch.cuda.amp import autocast

class CheckpointedWithAMP(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers  # e.g. an nn.Sequential of transformer blocks

    def forward(self, x):
        def custom_forward(x):
            # the autocast context is re-entered when the segment is recomputed
            with autocast():
                return self.layers(x)
        return checkpoint(custom_forward, x, use_reentrant=False)
Checkpointing + Gradient Accumulation
training_args = TrainingArguments(
per_device_train_batch_size=2, # Small batch for memory
    gradient_accumulation_steps=16,  # Effective batch per device: 2 × 16 = 32
gradient_checkpointing=True, # Further memory savings
)
Common Issues and Solutions
Issue 1: Slower Training
Checkpointing recomputes forward pass during backward, roughly doubling compute for checkpointed sections.
Solution: Checkpoint selectively, not every layer:
# Instead of checkpointing every layer:
for layer in self.layers:
    x = checkpoint(layer, x, use_reentrant=False)

# Checkpoint only every Nth layer:
for i, layer in enumerate(self.layers):
    if i % 4 == 0:  # every 4th layer
        x = checkpoint(layer, x, use_reentrant=False)
    else:
        x = layer(x)
Issue 2: Non-Deterministic Gradients
Stochastic operations such as dropout can draw different random numbers when the forward is recomputed, which would make the recomputed activations (and hence the gradients) inconsistent.
Solution: keep preserve_rng_state=True (it is already the default), so the RNG state is saved and replayed during recomputation:
from torch.utils.checkpoint import checkpoint
output = checkpoint(
fn,
input,
use_reentrant=False,
    preserve_rng_state=True,  # default; replays the RNG state so dropout masks match
)
Issue 3: Hooks Not Working
With use_reentrant=True, hooks on checkpointed modules may not fire.
Solution: Use use_reentrant=False:
# This ensures hooks work correctly
output = checkpoint(fn, x, use_reentrant=False)
Issue 4: CUDA Graphs Incompatibility
Checkpointing is incompatible with CUDA graphs due to dynamic recomputation.
Solution: disable checkpointing before capturing CUDA graphs. For inference this costs nothing, because checkpointing only matters when activations are kept for a backward pass:
# For inference or CUDA-graph capture: turn checkpointing off
model.gradient_checkpointing_disable()
Memory Estimation
Estimate memory savings before implementing:
def estimate_checkpoint_savings(model, batch_size, seq_len, num_checkpoints):
"""
Estimate memory savings from checkpointing.
"""
# Activation memory without checkpointing
    num_layers = model.config.num_hidden_layers  # transformer blocks, not every nn.Module
    bytes_per_activation = 2  # fp16 / bf16
    hidden_dim = model.config.hidden_size
# Rough estimate of activations per layer
activations_per_layer = batch_size * seq_len * hidden_dim * 4 # Multiple intermediates
total_without = num_layers * activations_per_layer * bytes_per_activation
# With checkpointing
segment_size = num_layers // num_checkpoints
total_with = segment_size * activations_per_layer * bytes_per_activation
savings = (total_without - total_with) / 1e9
print(f"Estimated savings: {savings:.2f} GB")
return savings
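A hedged example call, assuming model is the LLaMA-style model loaded earlier (so config.hidden_size and config.num_hidden_layers exist) and a 2,048-token context:

# Rough savings for batch 8, 2K context, 6 checkpointed segments
estimate_checkpoint_savings(model, batch_size=8, seq_len=2048, num_checkpoints=6)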
References
- Chen, T., et al. (2016). "Training Deep Nets with Sublinear Memory Cost." arXiv:1604.06174.
- Griewank, A., & Walther, A. (2000). "Algorithm 799: Revolve: An Implementation of Checkpointing for the Reverse or Adjoint Mode of Computational Differentiation." ACM Transactions on Mathematical Software.
- PyTorch documentation, "torch.utils.checkpoint."
- Korthikanti, V., et al. (2022). "Reducing Activation Recomputation in Large Transformer Models." arXiv:2205.05198.