Mixed precision training reduces memory usage and speeds up training by using lower-precision formats for most operations. This guide explains the differences between FP16 and BF16, and how to implement mixed precision correctly.
Understanding Floating Point Formats
Format Comparison
| Format | Sign | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|
| FP32 | 1 bit | 8 bits | 23 bits | ±3.4×10³⁸ | ~7 decimal digits |
| FP16 | 1 bit | 5 bits | 10 bits | ±65,504 | ~3 decimal digits |
| BF16 | 1 bit | 8 bits | 7 bits | ±3.4×10³⁸ | ~2 decimal digits |
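These properties can be read directly from torch.finfo, which is a quick way to sanity-check what each dtype can represent:
import torch
for dt in (torch.float32, torch.float16, torch.bfloat16):
    fi = torch.finfo(dt)
    print(dt, "max:", fi.max, "smallest normal:", fi.tiny, "eps:", fi.eps)
# float16 max = 65504.0; bfloat16 max ≈ 3.39e38, the same order of magnitude as float32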
Key Differences
FP16 (Float16):
- Higher precision than BF16 (10 mantissa bits)
- Limited range (max ~65,504)
- Requires loss scaling to prevent gradient underflow
- Supported on all CUDA GPUs
BF16 (BFloat16):
- Lower precision than FP16 (7 mantissa bits)
- Same dynamic range as FP32
- No loss scaling needed
- Requires Ampere or newer GPUs (A100, RTX 30/40 series, H100)
import torch
# Check BF16 support
print(f"BF16 supported: {torch.cuda.is_bf16_supported()}")
# Format examples: a value below FP16's smallest subnormal (~6e-8) underflows
fp32_val = torch.tensor(1e-8, dtype=torch.float32)
fp16_val = fp32_val.to(torch.float16)   # Underflows to 0
bf16_val = fp32_val.to(torch.bfloat16)  # Keeps the value (at lower precision)
Why Mixed Precision Works
Mixed precision uses lower precision where it doesn't hurt accuracy:
- Forward pass: FP16/BF16 (fast, memory efficient)
- Backward pass: FP16/BF16 (fast, memory efficient)
- Master weights: FP32 (maintains accuracy)
- Optimizer step: FP32 (numerical stability)
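As a rough sketch of the master-weight idea (torch.autocast does the per-op casting for you; this is only to illustrate why FP32 copies are kept), one could run the forward in BF16 against FP32 parameters:
import torch
# Illustrative only: FP32 master weight, BF16 forward, FP32 update
master_w = torch.randn(1024, 1024, device='cuda', requires_grad=True)  # FP32 master copy
opt = torch.optim.SGD([master_w], lr=1e-3)
x = torch.randn(32, 1024, device='cuda')
y = x.to(torch.bfloat16) @ master_w.to(torch.bfloat16)  # forward in BF16
loss = y.float().pow(2).mean()                          # loss reduced in FP32
loss.backward()                                         # gradient lands on the FP32 master
opt.step()                                              # optimizer update in FP32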
Memory Savings
| Component | FP32 | Mixed Precision |
|---|---|---|
| Model weights | 4 bytes/param | 2 bytes/param (+ 4 bytes master) |
| Gradients | 4 bytes/param | 2 bytes/param |
| Activations | 4 bytes/value | 2 bytes/value |
| 7B model total | ~56 GB | ~35 GB |
Speed Improvements
Modern NVIDIA GPUs have dedicated Tensor Core hardware for lower-precision matrix math:
| GPU | FP32 TFLOPS | FP16 Tensor TFLOPS (dense) | BF16 Tensor TFLOPS (dense) |
|---|---|---|---|
| RTX 3090 | 36 | 71 | 71 |
| RTX 4090 | 83 | 165 | 165 |
| A100 | 19.5 | 312 | 312 |
| H100 (SXM) | 67 | 990 | 990 |
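If you want to gauge these differences on your own GPU, here is a rough matmul throughput sketch (illustrative only; end-to-end training gains are smaller than raw GEMM throughput):
import time
import torch
def matmul_tflops(dtype, n=4096, iters=50):
    a = torch.randn(n, n, device='cuda', dtype=dtype)
    b = torch.randn(n, n, device='cuda', dtype=dtype)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()
    return 2 * n**3 * iters / (time.perf_counter() - start) / 1e12
for dt in (torch.float32, torch.float16, torch.bfloat16):
    print(dt, f"{matmul_tflops(dt):.1f} TFLOPS")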
FP16 Training with Loss Scaling
FP16's limited range can cause gradient underflow. Loss scaling prevents this:
from torch.amp import autocast, GradScaler
# Initialize the dynamic loss scaler (torch.amp replaces the older torch.cuda.amp API)
scaler = GradScaler('cuda')
for batch in dataloader:
    optimizer.zero_grad()
    # Forward pass in FP16
    with autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(batch['input_ids'])
        loss = criterion(outputs, batch['labels'])
    # Scale the loss to prevent gradient underflow
    scaler.scale(loss).backward()
    # Unscale gradients, then clip
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Optimizer step (skipped automatically if inf/nan gradients are found)
    scaler.step(optimizer)
    scaler.update()
How Loss Scaling Works
Without scaling:
Gradient = 0.00000001 (1e-8) → FP16: 0 (underflow: below FP16's smallest subnormal, ~6e-8)
With scaling (scale = 65536, the default 2**16):
Scaled gradient = 1e-8 × 65536 = 6.5536e-4 → representable in FP16 ✓
After unscaling (done in FP32): 6.5536e-4 / 65536 ≈ 1e-8 ✓
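You can reproduce the effect directly in PyTorch:
import torch
g = torch.tensor(1e-8)                                # tiny FP32 gradient value
print(g.to(torch.float16))                            # tensor(0., dtype=torch.float16) - underflow
print((g * 65536).to(torch.float16))                  # representable after scaling by 2**16
print((g * 65536).to(torch.float16).float() / 65536)  # ~1e-8 recovered in FP32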
The GradScaler dynamically adjusts the scale factor:
- Increases scale when gradients are healthy
- Decreases scale when overflow (inf/nan) detected
- Skips optimizer step on overflow
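The defaults can be inspected directly; a quick sketch (assuming PyTorch 2.3+ for torch.amp.GradScaler):
import torch
scaler = torch.amp.GradScaler('cuda')  # defaults: init_scale=2**16, growth_factor=2.0,
                                       #           backoff_factor=0.5, growth_interval=2000
print(scaler.get_scale())              # 65536.0
# After a step with inf/nan gradients, update() multiplies the scale by backoff_factor (0.5);
# after growth_interval consecutive clean steps, it multiplies it by growth_factor (2.0).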
BF16 Training (Recommended)
BF16 is simpler because it doesn't need scaling:
# BF16 training - no scaler needed!
for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs = model(batch['input_ids'])
        loss = criterion(outputs, batch['labels'])
    loss.backward()
    optimizer.step()
When to Choose BF16 vs FP16
| Scenario | Recommendation |
|---|---|
| Ampere+ GPU (A100, RTX 30/40, H100) | BF16 |
| Older GPU (V100, RTX 20 series) | FP16 with scaling |
| Numerical stability issues with FP16 | BF16 |
| Maximum precision needed | FP16 (more mantissa bits) |
| Training large language models | BF16 |
Hugging Face Integration
TrainingArguments
from transformers import TrainingArguments
# FP16 training
training_args = TrainingArguments(
    output_dir="./output",
    fp16=True,
    fp16_full_eval=True,
)
# BF16 training (recommended for Ampere+)
training_args = TrainingArguments(
    output_dir="./output",
    bf16=True,
    bf16_full_eval=True,
)
Automatic Selection
import torch
def get_mixed_precision_dtype():
    """Select the best mixed-precision flag for the current hardware."""
    if not torch.cuda.is_available():
        return "no"  # CPU fallback: no mixed precision
    if torch.cuda.is_bf16_supported():
        return "bf16"
    return "fp16"
precision = get_mixed_precision_dtype()
training_args = TrainingArguments(
    output_dir="./output",
    # Only pass fp16/bf16 when mixed precision is actually available
    **({precision: True} if precision != "no" else {}),
)
Operations That Require FP32
Some operations are numerically unstable in lower precision:
Operations that should stay in FP32:
- softmax (exponentials can overflow)
- layer_norm (statistics need precision)
- loss functions such as cross_entropy (log operations are sensitive)
- the optimizer step (weight updates need precision)
PyTorch autocast handles the first three automatically; the optimizer step stays in FP32 because the parameters themselves remain FP32:
import torch.nn.functional as F
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    # These ops run in FP32 even inside the autocast region:
    probs = F.softmax(logits, dim=-1)
    h = F.layer_norm(hidden, normalized_shape)
    loss = F.cross_entropy(logits, targets)
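A small sanity check (assuming a CUDA device) shows the per-op dtype decisions autocast makes:
import torch
import torch.nn.functional as F
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    a = torch.randn(8, 8, device='cuda')
    b = torch.randn(8, 8, device='cuda')
    y = a @ b                  # matmul autocasts to BF16
    p = F.softmax(y, dim=-1)   # softmax autocasts to FP32
print(y.dtype)  # torch.bfloat16
print(p.dtype)  # torch.float32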
Custom Autocast Rules
# Disable autocast locally to force an operation into FP32
def stable_softmax(x):
    with torch.autocast(device_type='cuda', enabled=False):
        return F.softmax(x.float(), dim=-1)
# The same pattern works as a nested context manager inside an autocast region
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    x = model(inputs)  # runs in BF16 under autocast
    # Disable autocast for sensitive operations
    with torch.autocast(device_type='cuda', enabled=False):
        loss = sensitive_loss_function(x.float())
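For a custom torch.autograd.Function, the torch.amp.custom_fwd / custom_bwd decorators (which replace the older torch.cuda.amp variants in PyTorch 2.4+) can cast the forward inputs to FP32 automatically; a minimal sketch:
import torch
import torch.nn.functional as F
class StableSoftmax(torch.autograd.Function):
    """Softmax whose forward always runs in FP32, even under autocast."""
    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(ctx, x):
        y = F.softmax(x, dim=-1)
        ctx.save_for_backward(y)
        return y
    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        # Jacobian-vector product of softmax: y * (g - sum(g * y))
        return y * (grad_out - (grad_out * y).sum(dim=-1, keepdim=True))
# Usage: probs = StableSoftmax.apply(logits)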
Common Issues and Solutions
Issue 1: Loss Goes to NaN
Cause: Gradient overflow or unstable operations
# Solution 1: Use BF16 instead of FP16
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    ...
# Solution 2: Adjust GradScaler settings
scaler = GradScaler(
    init_scale=2**10,      # Lower initial scale (default 2**16)
    growth_interval=4000,  # Grow the scale less often (default 2000)
)
# Solution 3: Clip gradients (after unscaling)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
scaler.step(optimizer)
Issue 2: Model Converges Slower
Cause: Loss of precision in gradients
# Solution: Keep master weights in FP32
# This is automatic with torch.autocast (parameters are never converted), but verify:
for param in model.parameters():
    assert param.dtype == torch.float32  # FP32 master weights
# During forward/backward, autocast casts inputs to FP16/BF16 per-op on the fly
Issue 3: Inference Differs from Training
Cause: Not using autocast during inference
# Match training precision during inference
model.eval()
with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs = model(inputs)
Performance Benchmarks
Training Throughput (samples/second)
LLaMA-7B on A100 80GB:
| Precision | Throughput | Memory | vs FP32 |
|---|---|---|---|
| FP32 | 12.4 | 75 GB | 1.0x |
| FP16 + scaling | 28.6 | 42 GB | 2.3x |
| BF16 | 29.2 | 42 GB | 2.4x |
Memory Comparison
| Model Size | FP32 | Mixed Precision | Savings |
|---|---|---|---|
| 1B | 8 GB | 4.5 GB | 44% |
| 7B | 56 GB | 32 GB | 43% |
| 13B | 104 GB | 60 GB | 42% |
| 70B | 560 GB | 320 GB | 43% |
Advanced: FP8 Training
H100 GPUs support FP8 for even more performance:
# FP8 (Hopper GPUs and newer)
# Two FP8 formats are used together:
# E4M3: 4 exponent bits, 3 mantissa bits - forward pass (weights/activations)
# E5M2: 5 exponent bits, 2 mantissa bits - gradients (wider range)
# Native FP8 training is still maturing in PyTorch; in practice it is usually done
# through NVIDIA's Transformer Engine or the torchao float8 APIs - check their
# latest documentation.
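As an illustration of how FP8 is typically used today, a minimal sketch with NVIDIA's Transformer Engine (a separate library, assumed installed; requires Hopper-class hardware, and the API may change between versions):
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
# HYBRID recipe: E4M3 for the forward pass, E5M2 for gradients
fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)
layer = te.Linear(4096, 4096).cuda()
x = torch.randn(16, 4096, device='cuda')
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)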
Complete Training Example
import torch
from torch.amp import autocast, GradScaler
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
# Setup
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float32,  # Master weights in FP32
).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
# Choose precision based on hardware
use_bf16 = torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16
scaler = GradScaler('cuda', enabled=not use_bf16)  # Loss scaling only for FP16
# Training loop (num_epochs and dataloader defined elsewhere)
model.train()
for epoch in range(num_epochs):
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        # Mixed precision forward pass
        with autocast(device_type='cuda', dtype=dtype):
            outputs = model(**batch)
            loss = outputs.loss
        # Backward pass
        if use_bf16:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
References
- Micikevicius, P., et al. (2018). "Mixed Precision Training." arXiv:1710.03740.
- NVIDIA. (2025). "Automatic Mixed Precision." NVIDIA Documentation.
- Kalamkar, D., et al. (2019). "A Study of BFLOAT16 for Deep Learning Training." arXiv:1905.12322.
- PyTorch. (2025). "Automatic Mixed Precision." PyTorch Documentation.