Running out of GPU memory is the most common blocker when training large models. This comprehensive guide explains exactly where your memory goes and provides proven techniques to fit larger models on your hardware.
Understanding GPU Memory
Memory Hierarchy
Modern GPUs have a tiered memory system:
| Memory Type | Size | Bandwidth | Latency | Scope |
|---|---|---|---|---|
| Registers | ~256KB/SM | ~20 TB/s | 0 cycles | Per thread |
| L1/Shared | 128KB/SM | ~19 TB/s | ~20 cycles | Per SM |
| L2 Cache | 40-60MB | ~5 TB/s | ~200 cycles | Global |
| HBM (Main) | 24-80GB | 1-3 TB/s | ~400 cycles | Global |
For training, we primarily interact with HBM (High Bandwidth Memory), which is what you see in nvidia-smi.
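You can confirm the HBM capacity (and SM count) of your card directly from PyTorch; here is a minimal sketch using torch.cuda.get_device_properties:

import torch

props = torch.cuda.get_device_properties(0)
print(f"GPU: {props.name}")
print(f"HBM capacity: {props.total_memory / 1e9:.1f} GB")  # the number nvidia-smi reports
print(f"Streaming multiprocessors: {props.multi_processor_count}")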
Where Training Memory Goes
When training a model, GPU memory is consumed by four main components:
Total Memory = Model Parameters + Optimizer States + Gradients + Activations
Let's break down each component for a 7B parameter model:
1. Model Parameters
The model weights themselves:
| Precision | Bytes/Param | 7B Model Size |
|---|---|---|
| FP32 | 4 bytes | 28 GB |
| FP16/BF16 | 2 bytes | 14 GB |
| INT8 | 1 byte | 7 GB |
| INT4 | 0.5 bytes | 3.5 GB |
2. Optimizer States
AdamW stores two additional values per parameter:
# AdamW optimizer state
m = β₁ * m + (1 - β₁) * gradient # First moment (momentum)
v = β₂ * v + (1 - β₂) * gradient² # Second moment (variance)
Memory for Adam with FP32 states:
- Momentum (m): 4 bytes/param
- Variance (v): 4 bytes/param
- Total: 8 bytes/param
For 7B model: 7B × 8 = 56 GB for optimizer states alone
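To see these numbers on a real model rather than estimate them, you can sum the storage of every state tensor the optimizer holds. This is a rough sketch (the helper name optimizer_state_bytes is ours); note that AdamW only materializes its states after the first optimizer.step().

import torch

def optimizer_state_bytes(optimizer):
    """Sum the memory of all tensors in the optimizer state (e.g. exp_avg, exp_avg_sq)."""
    total = 0
    for state in optimizer.state.values():  # one state dict per parameter
        for value in state.values():
            if torch.is_tensor(value):
                total += value.numel() * value.element_size()
    return total

# After at least one training step:
# print(f"Optimizer states: {optimizer_state_bytes(optimizer) / 1e9:.2f} GB")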
3. Gradients
Same size as parameters in the precision used:
| Precision | 7B Model Gradients |
|---|---|
| FP32 | 28 GB |
| FP16 | 14 GB |
4. Activations
Activations are the intermediate values saved for the backward pass. This is where memory explodes with batch size and sequence length:
# For a transformer layer
activation_memory ≈ batch_size × seq_len × hidden_dim × num_layers × bytes_per_value
For a 7B model (32 layers, hidden=4096, seq=2048):
- Per token per layer: 4096 × 2 bytes = 8 KB
- Per sequence per layer: 2048 × 8 KB = 16 MB
- Per batch item: 32 × 16 MB = 512 MB
- Batch size 8: 4 GB for activations
Attention activations are particularly expensive:
attention_activations = batch × heads × seq × seq × 2 bytes
# For seq=2048, heads=32: batch × 32 × 2048 × 2048 × 2 = batch × 256 MB
Total Memory Example: 7B Model
| Component | FP32 Training | Mixed Precision |
|---|---|---|
| Parameters | 28 GB | 14 GB (FP16) |
| Optimizer States | 56 GB | 28 GB (FP16) or 56 GB (FP32 master) |
| Gradients | 28 GB | 14 GB |
| Activations | 20+ GB | 20+ GB |
| Total | 132+ GB | 76+ GB |
This is why you can't train a 7B model on a single 24GB consumer GPU without optimization techniques.
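The back-of-the-envelope arithmetic above can be wrapped into a small helper. This is only a rough sketch under the same simplifying assumptions (one saved hidden-state tensor per layer, attention score matrices and framework overhead excluded); the function name estimate_training_memory_gb is ours.

def estimate_training_memory_gb(
    n_params,       # e.g. 7e9
    n_layers,       # e.g. 32
    hidden_dim,     # e.g. 4096
    seq_len,        # e.g. 2048
    batch_size,     # e.g. 8
    param_bytes=2,  # FP16/BF16 weights
    grad_bytes=2,   # FP16/BF16 gradients
    optim_bytes=8,  # AdamW: two FP32 states per parameter
    act_bytes=2,    # FP16/BF16 activations
):
    params = n_params * param_bytes
    grads = n_params * grad_bytes
    optim = n_params * optim_bytes
    acts = batch_size * seq_len * hidden_dim * n_layers * act_bytes
    return (params + grads + optim + acts) / 1e9

# 7B model, batch size 8, seq 2048, FP16 weights/grads/activations, FP32 Adam states:
# estimate_training_memory_gb(7e9, 32, 4096, 2048, 8)  -> roughly 88 GB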
Diagnosing Memory Issues
Reading CUDA Memory Stats
import torch
def print_memory_stats():
    """Print detailed GPU memory statistics"""
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# Reset peak stats before training
torch.cuda.reset_peak_memory_stats()

# After training step
print_memory_stats()
Memory Profiling
For detailed analysis, use PyTorch's memory profiler:
import torch
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
) as prof:
    # Your training step
    loss = model(inputs)
    loss.backward()

# Print operators sorted by CUDA memory usage
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=20))

# Export a Chrome trace (open in chrome://tracing or ui.perfetto.dev)
prof.export_chrome_trace("memory_trace.json")
Finding Memory Leaks
Common sources of memory leaks:
# LEAK: accumulating losses that still carry the computation graph
total_loss = 0.0
for batch in dataloader:
    loss = model(batch)
    total_loss += loss  # keeps the whole graph in memory!
    # FIX: detach from the graph, e.g.
    # total_loss += loss.item()

# LEAK: storing GPU tensors in a Python list
all_outputs = []
for batch in dataloader:
    output = model(batch)
    all_outputs.append(output)  # fills GPU memory!
    # FIX: detach and move to CPU instead
    # all_outputs.append(output.detach().cpu())

# LEAK: not clearing the allocator cache between experiments
# FIX: clear the cache between runs
torch.cuda.empty_cache()
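If you suspect a leak but cannot find it, one common trick is to enumerate the CUDA tensors Python still holds references to via the garbage collector. This is a rough diagnostic sketch (the helper name list_live_cuda_tensors is ours), not part of any library API:

import gc
import torch

def list_live_cuda_tensors():
    """Print every CUDA tensor the garbage collector can still reach."""
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                print(type(obj).__name__, tuple(obj.shape), obj.dtype)
        except Exception:
            pass  # some objects raise on attribute access

# Call between training steps or experiments to spot tensors that should have been freed
list_live_cuda_tensors()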
Memory Optimization Techniques
1. Mixed Precision Training
Mixed precision uses FP16 for most operations while keeping critical values in FP32:
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    inputs, targets = batch

    # Forward pass runs most ops in FP16
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Backward pass with loss scaling to avoid FP16 underflow
    scaler.scale(loss).backward()

    # Unscale gradients, then take the optimizer step
    scaler.step(optimizer)
    scaler.update()
Memory savings: ~50% reduction in model and gradient memory
BF16 vs FP16:
| Feature | FP16 | BF16 |
|---|---|---|
| Range | Limited (needs scaling) | Same as FP32 |
| Precision | Higher | Lower |
| Hardware | All GPUs | Ampere+ |
| Stability | Needs GradScaler | No scaling needed |
# BF16 is simpler (no GradScaler needed)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, targets)

loss.backward()
2. Gradient Checkpointing
Trade compute for memory by not storing all activations:
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformerBlock(nn.Module):
    # self.attention and self.feedforward are defined in __init__ (omitted here)

    def forward(self, x):
        # Without checkpointing: all intermediate activations are stored for backward
        # With checkpointing: they are recomputed during the backward pass
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = self.attention(x)
        x = self.feedforward(x)
        return x
For Hugging Face models:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()

# Or in TrainingArguments
training_args = TrainingArguments(
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
Memory savings: 60-70% reduction in activation memory
Compute cost: ~30% slower training
3. Gradient Accumulation
Simulate larger batch sizes without memory increase:
accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
With Hugging Face:
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    # Effective batch size: 4 * 8 = 32
)
4. Flash Attention
Replace standard attention with memory-efficient Flash Attention:
# Standard attention memory: O(seq_len²)
# Flash Attention memory: O(seq_len)
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
Memory savings: 5-20x for attention; critical for long sequences
Speed: 2-4x faster attention computation
5. Optimizer State Optimization
8-bit Optimizers
bitsandbytes provides memory-efficient optimizers:
import bitsandbytes as bnb

# Standard Adam: 8 bytes/param for states (two FP32 moments)
# 8-bit Adam: ~2 bytes/param for states (two 8-bit moments plus small quantization constants)
optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
)
Memory savings: 75% reduction in optimizer states
Paged Optimizers
Paged optimizers keep optimizer states in paged memory that can spill to CPU RAM during GPU memory spikes instead of triggering an OOM:
optimizer = bnb.optim.PagedAdamW32bit(
    model.parameters(),
    lr=1e-4,
)
6. Model Quantization for Training (QLoRA)
Train with quantized base model:
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
# 7B model: 28GB → 3.5GB for weights
Memory savings: 75-87.5% reduction in model weights
7. CPU Offloading
Move optimizer states or parameters to CPU:
from accelerate import Accelerator

# Accelerate has no standalone "offload optimizer states" flag for training;
# CPU offloading is configured through DeepSpeed ZeRO (or FSDP) instead.
accelerator = Accelerator(mixed_precision="bf16")

# DeepSpeed ZeRO-3 config offloading optimizer states and parameters to CPU
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"},
    }
}
Trade-off: Slower training due to CPU-GPU transfer
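With Accelerate, the same offloading can also be requested through its DeepSpeed plugin rather than a hand-written config. A rough sketch, assuming DeepSpeed is installed and that the plugin's offload options are available in your accelerate version:

from accelerate import Accelerator, DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    zero_stage=3,
    offload_optimizer_device="cpu",
    offload_param_device="cpu",
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=ds_plugin)

# Accelerate hands the model, optimizer, and dataloader to DeepSpeed
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)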
Memory Optimization Cheatsheet
Quick Fixes (No Code Changes)
| Technique | Memory Saved | Speed Impact |
|---|---|---|
| Reduce batch size | Linear | Slower convergence |
| Reduce sequence length | Quadratic | Less context |
| Use BF16/FP16 | 50% | Minimal |
| Clear cache | Variable | None |
Code Changes (Moderate Effort)
| Technique | Memory Saved | Speed Impact |
|---|---|---|
| Gradient checkpointing | 60-70% activations | 30% slower |
| Flash Attention | 5-20x attention | 2-4x faster |
| Gradient accumulation | Proportional to steps | Slightly slower |
| 8-bit optimizer | 75% optimizer states | Minimal |
Major Changes (High Effort)
| Technique | Memory Saved | Speed Impact |
|---|---|---|
| QLoRA | 75%+ model | Similar |
| DeepSpeed ZeRO | Distributed | Communication overhead |
| CPU offloading | Unlimited | 2-5x slower |
Putting It All Together
Here's a complete memory-optimized training setup:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

# 1. Quantize the base model (QLoRA)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",  # 2. Flash Attention
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# 3. Prepare the quantized model for training
model = prepare_model_for_kbit_training(model)

# 4. Add LoRA adapters (only the small adapter matrices are trained)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# 5. Training arguments with memory optimizations
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size: 16
    gradient_checkpointing=True,  # 6. Gradient checkpointing
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,  # 7. Mixed precision
    optim="paged_adamw_8bit",  # 8. 8-bit paged optimizer
    learning_rate=2e-4,
    warmup_ratio=0.03,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="epoch",
)

# Train (`dataset` is assumed to be a loaded dataset with a "text" column)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=2048,
)
trainer.train()
Memory requirement: ~6GB for 7B model (down from 130+ GB)
Debugging OOM Errors
Step-by-Step OOM Debugging
# 1. Record allocation history so the OOM site can be identified
torch.cuda.memory._record_memory_history(max_entries=100000)

try:
    train_step()
except RuntimeError as e:
    if "out of memory" in str(e):
        torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
    raise

# 2. Analyze the snapshot: load oom_snapshot.pickle in the memory_viz tool
#    (https://pytorch.org/memory_viz) to see the allocation timeline

# 3. Common fixes, depending on where the OOM occurs:
#    - In the forward pass: reduce batch size or sequence length,
#      enable gradient checkpointing, use Flash Attention
#    - In the backward pass: enable gradient checkpointing, use gradient
#      accumulation with a smaller batch, check for retained computation graphs
#    - In the optimizer step: use an 8-bit optimizer, offload optimizer states
Memory Monitoring During Training
import torch
from transformers import TrainerCallback

class MemoryMonitorCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"Step {state.global_step}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
        if allocated > 0.9 * total:
            print("WARNING: Approaching memory limit!")

# Register with the Trainer: trainer.add_callback(MemoryMonitorCallback())
Hardware-Specific Tips
Consumer GPUs (RTX 3090, 4090)
- Max 24GB VRAM
- Use QLoRA for 7B+ models
- Flash Attention is critical
- Always enable gradient checkpointing
Professional GPUs (A100)
- 40GB or 80GB options
- Can often skip quantization for 7B models
- Full fine-tuning possible for smaller models
- Multi-GPU beneficial for 70B+ models
Memory per GPU Generation
| GPU | VRAM | Bandwidth | Best For |
|---|---|---|---|
| RTX 3080 | 10GB | 760 GB/s | Inference, small training |
| RTX 3090 | 24GB | 936 GB/s | QLoRA 7B-13B |
| RTX 4090 | 24GB | 1008 GB/s | QLoRA 7B-13B, faster |
| A100 40GB | 40GB | 1555 GB/s | LoRA 7B-13B, QLoRA 70B |
| A100 80GB | 80GB | 2039 GB/s | Full FT 7B, LoRA 70B |
| H100 | 80GB | 3350 GB/s | Everything, faster |
References
- Rajbhandari, S., et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." arXiv:1910.02054
- Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691
- Dettmers, T., et al. (2023). "QLoRA: Efficient Finetuning of Quantized LLMs." arXiv:2305.14314
- Micikevicius, P., et al. (2018). "Mixed Precision Training." arXiv:1710.03740
- Chen, T., et al. (2016). "Training Deep Nets with Sublinear Memory Cost." arXiv:1604.06174