
GPU Memory Optimization for Deep Learning: A Complete Guide

Master GPU memory optimization for training large models. Covers memory anatomy, OOM debugging, gradient checkpointing, mixed precision, and advanced techniques with practical PyTorch examples.

Flash Attention Team · January 8, 2026 · 11 min read
Tags: GPU memory, CUDA memory, OOM error, memory optimization, PyTorch training, deep learning

Running out of GPU memory is the most common blocker when training large models. This comprehensive guide explains exactly where your memory goes and provides proven techniques to fit larger models on your hardware.

Understanding GPU Memory

Memory Hierarchy

Modern GPUs have a tiered memory system:

Memory Type | Size      | Bandwidth | Latency     | Scope
------------|-----------|-----------|-------------|-----------
Registers   | ~256KB/SM | ~20 TB/s  | 0 cycles    | Per thread
L1/Shared   | 128KB/SM  | ~19 TB/s  | ~20 cycles  | Per SM
L2 Cache    | 40-60MB   | ~5 TB/s   | ~200 cycles | Global
HBM (Main)  | 24-80GB   | 1-3 TB/s  | ~400 cycles | Global

For training, we primarily interact with HBM (High Bandwidth Memory), which is what you see in nvidia-smi.
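
You can check how much HBM PyTorch sees on your card directly. A quick sketch, assuming at least one CUDA device is available:

import torch

props = torch.cuda.get_device_properties(0)
print(f"GPU: {props.name}")
print(f"Total HBM: {props.total_memory / 1e9:.1f} GB")  # the total nvidia-smi reports
print(f"Streaming multiprocessors: {props.multi_processor_count}")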

Where Training Memory Goes

When training a model, GPU memory is consumed by four main components:

Total Memory = Model Parameters + Optimizer States + Gradients + Activations

Let's break down each component for a 7B parameter model:

1. Model Parameters

The model weights themselves:

Precision | Bytes/Param | 7B Model Size
----------|-------------|--------------
FP32      | 4 bytes     | 28 GB
FP16/BF16 | 2 bytes     | 14 GB
INT8      | 1 byte      | 7 GB
INT4      | 0.5 bytes   | 3.5 GB
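
To verify the weight footprint of a model you have actually loaded, summing tensor sizes is enough. A small sketch; it assumes `model` is any nn.Module already instantiated (for example, loaded with torch_dtype=torch.bfloat16):

import torch.nn as nn

def parameter_gb(model: nn.Module) -> float:
    """Total memory occupied by parameter tensors, in GB."""
    return sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9

# print(f"Parameter memory: {parameter_gb(model):.2f} GB")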

2. Optimizer States

AdamW stores two additional values per parameter:

# AdamW optimizer state (per parameter)
m = beta1 * m + (1 - beta1) * grad        # First moment (momentum)
v = beta2 * v + (1 - beta2) * grad ** 2   # Second moment (variance)

Memory for Adam with FP32 states:

  • Momentum (m): 4 bytes/param
  • Variance (v): 4 bytes/param
  • Total: 8 bytes/param

For 7B model: 7B × 8 = 56 GB for optimizer states alone
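
You can confirm this on a live run by summing the state tensors AdamW keeps per parameter (exp_avg and exp_avg_sq). A sketch, assuming a standard torch.optim optimizer such as AdamW that has taken at least one step:

import torch

def optimizer_state_gb(optimizer: torch.optim.Optimizer) -> float:
    """Memory held by optimizer state tensors (e.g. Adam's exp_avg / exp_avg_sq)."""
    total = 0
    for state in optimizer.state.values():
        for value in state.values():
            if torch.is_tensor(value):
                total += value.numel() * value.element_size()
    return total / 1e9

# print(f"Optimizer states: {optimizer_state_gb(optimizer):.2f} GB")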

3. Gradients

Same size as parameters in the precision used:

Precision | 7B Model Gradients
----------|-------------------
FP32      | 28 GB
FP16      | 14 GB
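
The same kind of measurement works for gradients once a backward pass has populated them. A sketch, assuming `model` has just run loss.backward() and gradients have not been zeroed yet:

import torch.nn as nn

def gradient_gb(model: nn.Module) -> float:
    """Memory held by .grad tensors after a backward pass."""
    return sum(
        p.grad.numel() * p.grad.element_size()
        for p in model.parameters()
        if p.grad is not None
    ) / 1e9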

4. Activations

Activations are the intermediate values saved for the backward pass. This is where memory explodes with batch size and sequence length:

# Simplified lower bound: one hidden-state tensor saved per transformer layer
activation_memory ≈ batch_size × seq_len × hidden_dim × num_layers × bytes_per_value
# Real layers also save attention scores, MLP intermediates, and norms,
# so actual usage is a multiple of this.

For a 7B model (32 layers, hidden=4096, seq=2048):

  • Per token per layer: 4096 × 2 bytes = 8 KB
  • Per sequence per layer: 2048 × 8 KB = 16 MB
  • Per batch item: 32 × 16 MB = 512 MB
  • Batch size 8: 4 GB for activations

Attention activations are particularly expensive:

attention_activations ≈ batch × heads × seq × seq × 2 bytes   (per layer, without Flash Attention)
# For seq=2048, heads=32: batch × 32 × 2048 × 2048 × 2 bytes = batch × 256 MB per layer

Total Memory Example: 7B Model

Component           | FP32 Training | Mixed Precision
--------------------|---------------|----------------
Parameters          | 28 GB         | 14 GB (FP16)
Optimizer States    | 56 GB         | 28 GB (FP16) or 56 GB (FP32 master)
Gradients           | 28 GB         | 14 GB
Activations         | 20+ GB        | 20+ GB
--------------------|---------------|----------------
Total               | 132+ GB       | 76+ GB

This is why you can't train a 7B model on a single 24GB consumer GPU without optimization techniques.
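
If you want to sanity-check these numbers for your own configuration, the arithmetic above is easy to wrap in a rough estimator. The function below is an illustrative sketch only: it uses the simplified activation formula from this section, so real usage (especially activations) will vary with the architecture and attention implementation.

def estimate_training_memory_gb(
    n_params: float = 7e9,
    param_bytes: int = 2,    # 2 for FP16/BF16 weights, 4 for FP32
    grad_bytes: int = 2,     # gradients usually match weight precision
    optim_bytes: int = 8,    # AdamW with FP32 moments: 4 (m) + 4 (v)
    batch_size: int = 8,
    seq_len: int = 2048,
    hidden_dim: int = 4096,
    num_layers: int = 32,
    act_bytes: int = 2,
) -> dict:
    """Back-of-the-envelope breakdown using the simplified formulas above."""
    parameters = n_params * param_bytes
    gradients = n_params * grad_bytes
    optimizer_states = n_params * optim_bytes
    # hidden-state activations only; attention maps and MLP intermediates add more
    activations = batch_size * seq_len * hidden_dim * num_layers * act_bytes
    breakdown = {
        "parameters": parameters,
        "gradients": gradients,
        "optimizer_states": optimizer_states,
        "activations": activations,
    }
    breakdown["total"] = sum(breakdown.values())
    return {name: round(b / 1e9, 1) for name, b in breakdown.items()}

print(estimate_training_memory_gb())
# {'parameters': 14.0, 'gradients': 14.0, 'optimizer_states': 56.0, 'activations': 4.3, 'total': 88.3}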

Diagnosing Memory Issues

Reading CUDA Memory Stats

import torch

def print_memory_stats():
    """Print detailed GPU memory statistics"""
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# Reset peak stats before training
torch.cuda.reset_peak_memory_stats()

# After training step
print_memory_stats()

Memory Profiling

For detailed analysis, use PyTorch's memory profiler:

import torch
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
) as prof:
    # Your training code
    loss = model(inputs)
    loss.backward()

# Print ops sorted by CUDA memory usage
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=20))

# Export a Chrome trace (open in chrome://tracing or Perfetto)
prof.export_chrome_trace("memory_trace.json")

Finding Memory Leaks

Common sources of memory leaks:

# LEAK: Accumulating gradients across iterations
for batch in dataloader:
    loss = model(batch)
    total_loss += loss  # Keeps computation graph!

# FIX: Detach or use .item()
total_loss += loss.item()

# LEAK: Storing tensors in lists
all_outputs = []
for batch in dataloader:
    output = model(batch)
    all_outputs.append(output)  # Fills memory!

# FIX: Move to CPU or detach
all_outputs.append(output.detach().cpu())

# NOT A TRUE LEAK, but it looks like one: the caching allocator keeps
# freed blocks reserved. Release them between experiments with:
torch.cuda.empty_cache()

Memory Optimization Techniques

1. Mixed Precision Training

Mixed precision uses FP16 for most operations while keeping critical values in FP32:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16
    with autocast():
        outputs = model(batch)
        loss = criterion(outputs)

    # Backward pass with scaling
    scaler.scale(loss).backward()

    # Optimizer step with unscaling
    scaler.step(optimizer)
    scaler.update()

Memory savings: roughly 50% of activation memory; storing weights and gradients in FP16/BF16 (e.g. loading with torch_dtype) halves those as well.

BF16 vs FP16:

Feature   | FP16                    | BF16
----------|-------------------------|-------------------
Range     | Limited (needs scaling) | Same as FP32
Precision | Higher                  | Lower
Hardware  | All GPUs                | Ampere+
Stability | Needs GradScaler        | No scaling needed

# BF16 is simpler (no GradScaler needed)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    outputs = model(batch)
    loss = criterion(outputs)
loss.backward()

2. Gradient Checkpointing

Trade compute for memory by not storing all activations:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, attention: nn.Module, feedforward: nn.Module):
        super().__init__()
        self.attention = attention
        self.feedforward = feedforward

    def forward(self, x):
        # Without checkpointing: all intermediate activations are stored
        # With checkpointing: they are recomputed during the backward pass
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = self.attention(x)
        x = self.feedforward(x)
        return x

For Hugging Face models:

import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,
)
model.gradient_checkpointing_enable()

# Or in TrainingArguments
training_args = TrainingArguments(
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

Memory savings: 60-70% of activation memory. Compute cost: ~30% slower training.

3. Gradient Accumulation

Simulate larger batch sizes without memory increase:

accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

With Hugging Face:

training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    # Effective batch size: 4 * 8 = 32
)

4. Flash Attention

Replace standard attention with memory-efficient Flash Attention:

# Standard attention memory: O(seq_len²)
# Flash Attention memory: O(seq_len)

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

Memory savings: 5-20x for attention, critical for long sequences. Speed: 2-4x faster attention computation.

5. Optimizer State Optimization

8-bit Optimizers

bitsandbytes provides memory-efficient optimizers:

import bitsandbytes as bnb

# Standard Adam: 8 bytes/param for states
# 8-bit Adam: ~2 bytes/param (1 byte each for m and v, plus small quantization constants)

optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
)

Memory savings: 75% reduction in optimizer states

Paged Optimizers

Paged optimizers place states in paged (unified) memory that spills to CPU RAM under pressure, avoiding OOM spikes when the model barely fits:

optimizer = bnb.optim.PagedAdamW32bit(
    model.parameters(),
    lr=1e-4,
)

6. Model Quantization for Training (QLoRA)

Train with quantized base model:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)

# 7B model: 28GB → 3.5GB for weights

Memory savings: 75-87.5% reduction in model weights

7. CPU Offloading

Move optimizer states or parameters to CPU:

from accelerate import Accelerator, DeepSpeedPlugin

# Offload optimizer states and parameters to CPU via DeepSpeed ZeRO-3
ds_plugin = DeepSpeedPlugin(
    zero_stage=3,
    offload_optimizer_device="cpu",
    offload_param_device="cpu",
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=ds_plugin)

# Equivalent raw DeepSpeed config
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"},
    }
}

Trade-off: Slower training due to CPU-GPU transfer

Memory Optimization Cheatsheet

Quick Fixes (No Code Changes)

Technique              | Memory Saved | Speed Impact
-----------------------|--------------|-------------------
Reduce batch size      | Linear       | Slower convergence
Reduce sequence length | Quadratic    | Less context
Use BF16/FP16          | 50%          | Minimal
Clear cache            | Variable     | None

Code Changes (Moderate Effort)

Technique              | Memory Saved             | Speed Impact
-----------------------|--------------------------|----------------
Gradient checkpointing | 60-70% of activations    | ~30% slower
Flash Attention        | 5-20x for attention      | 2-4x faster
Gradient accumulation  | Proportional to steps    | Slightly slower
8-bit optimizer        | 75% of optimizer states  | Minimal

Major Changes (High Effort)

Technique       | Memory Saved             | Speed Impact
----------------|--------------------------|------------------------
QLoRA           | 75%+ of model weights    | Similar
DeepSpeed ZeRO  | Distributed across GPUs  | Communication overhead
CPU offloading  | Limited by CPU RAM       | 2-5x slower

Putting It All Together

Here's a complete memory-optimized training setup:

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

# 1. Quantize model (QLoRA)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",  # 2. Flash Attention
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# 3. Prepare for training
model = prepare_model_for_kbit_training(model)

# 4. Add LoRA adapters (only train small matrices)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# 5. Training arguments with memory optimizations
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size: 16
    gradient_checkpointing=True,  # 6. Gradient checkpointing
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,  # 7. Mixed precision
    optim="paged_adamw_8bit",  # 8. 8-bit optimizer
    learning_rate=2e-4,
    warmup_ratio=0.03,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="epoch",
)

# Train (assumes `dataset` is a datasets.Dataset with a "text" column)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=2048,
)

trainer.train()

Memory requirement: ~6GB for 7B model (down from 130+ GB)

Debugging OOM Errors

Step-by-Step OOM Debugging

import torch

# 1. Record allocation history so the OOM point can be inspected
torch.cuda.memory._record_memory_history(max_entries=100000)

try:
    train_step()
except RuntimeError as e:
    if "out of memory" in str(e):
        torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
        raise

# 2. Analyze the snapshot: open oom_snapshot.pickle in the viewer at
#    https://pytorch.org/memory_viz to see which allocations dominate

# 3. Common fixes, depending on where the OOM happens:
#    - Forward pass: reduce batch size or sequence length,
#      enable gradient checkpointing, use Flash Attention
#    - Backward pass: enable gradient checkpointing, use gradient
#      accumulation with a smaller batch, check for memory leaks
#      (retained computation graphs)
#    - Optimizer step: use an 8-bit optimizer, enable optimizer
#      state offloading

Memory Monitoring During Training

import torch
from transformers import TrainerCallback

class MemoryMonitorCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"Step {state.global_step}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        if allocated > 0.9 * total:
            print("WARNING: Approaching memory limit!")

Hardware-Specific Tips

Consumer GPUs (RTX 3090, 4090)

  • Max 24GB VRAM
  • Use QLoRA for 7B+ models
  • Flash Attention is critical
  • Gradient checkpointing is almost always worth enabling

Professional GPUs (A100)

  • 40GB or 80GB options
  • Can often skip quantization for 7B models
  • Full fine-tuning possible for smaller models
  • Multi-GPU beneficial for 70B+ models

Memory per GPU Generation

GPU       | VRAM | Bandwidth | Best For
----------|------|-----------|--------------------------
RTX 3080  | 10GB | 760 GB/s  | Inference, small training
RTX 3090  | 24GB | 936 GB/s  | QLoRA 7B-13B
RTX 4090  | 24GB | 1008 GB/s | QLoRA 7B-13B, faster
A100 40GB | 40GB | 1555 GB/s | LoRA 7B-13B, QLoRA 70B
A100 80GB | 80GB | 2039 GB/s | Full FT 7B, LoRA 70B
H100      | 80GB | 3350 GB/s | Everything, faster

References

  1. Rajbhandari, S., et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." arXiv:1910.02054

  2. Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691

  3. Dettmers, T., et al. (2023). "QLoRA: Efficient Finetuning of Quantized LLMs." arXiv:2305.14314

  4. Micikevicius, P., et al. (2018). "Mixed Precision Training." arXiv:1710.03740

  5. Chen, T., et al. (2016). "Training Deep Nets with Sublinear Memory Cost." arXiv:1604.06174

