
Mixed Precision Training with FP16 and BF16: When and How to Use It

Complete guide to mixed precision training in PyTorch. Learn the differences between FP16 and BF16, when to use each, and how to implement stable training with code examples.

Flash Attention Team · January 8, 2026 · 8 min read
mixed precision · FP16 · BF16 · AMP · PyTorch training · GPU optimization

Mixed precision training reduces memory usage and speeds up training by using lower-precision formats for most operations. This guide explains the differences between FP16 and BF16, and how to implement mixed precision correctly.

Understanding Floating Point Formats

Format Comparison

Format | Sign  | Exponent | Mantissa | Range      | Precision
FP32   | 1 bit | 8 bits   | 23 bits  | ±3.4×10³⁸  | ~7 decimal digits
FP16   | 1 bit | 5 bits   | 10 bits  | ±65,504    | ~3 decimal digits
BF16   | 1 bit | 8 bits   | 7 bits   | ±3.4×10³⁸  | ~2 decimal digits
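
As a quick sanity check, these ranges and precisions can be read straight out of PyTorch with torch.finfo; the short snippet below just prints the limits PyTorch reports for each format.

import torch

# Print the numeric limits PyTorch reports for each format
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max:.3e}, smallest normal={info.tiny:.3e}, eps={info.eps:.3e}")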

Key Differences

FP16 (Float16):

  • Higher precision than BF16 (10 mantissa bits)
  • Limited range (max ~65,504)
  • Requires loss scaling to prevent underflow
  • Supported on all CUDA GPUs

BF16 (BFloat16):

  • Lower precision than FP16 (7 mantissa bits)
  • Same range as FP32
  • No scaling needed
  • Requires Ampere+ GPUs (A100, RTX 30/40 series)
import torch

# Check BF16 support
print(f"BF16 supported: {torch.cuda.is_bf16_supported()}")

# Format examples
fp32_val = torch.tensor(1e-8, dtype=torch.float32)
fp16_val = fp32_val.to(torch.float16)   # Underflows to 0 (below FP16's smallest subnormal, ~6e-8)
bf16_val = fp32_val.to(torch.bfloat16)  # Keeps the value (at reduced precision)

Why Mixed Precision Works

Mixed precision uses lower precision where it doesn't hurt accuracy:

Forward Pass:  FP16/BF16 (fast, memory efficient)
Backward Pass: FP16/BF16 (fast, memory efficient)
Master Weights: FP32 (maintains accuracy)
Optimizer Step: FP32 (numerical stability)
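
To make these roles concrete, the toy sketch below keeps an FP32 master weight, runs the matmul in FP16, and applies the update in FP32. The layer shape and learning rate are arbitrary; in practice torch.autocast and the optimizer handle all of this for you.

import torch

# Toy sketch of the master-weight pattern (AMP automates this)
master_w = torch.randn(1024, 1024, device="cuda", requires_grad=True)  # FP32 master weights
x = torch.randn(32, 1024, device="cuda", dtype=torch.float16)

# Forward/backward on a low-precision copy of the weights
w_fp16 = master_w.to(torch.float16)
loss = (x @ w_fp16.t()).float().pow(2).mean()
loss.backward()  # autograd accumulates the gradient on master_w in FP32

# Optimizer step in FP32 on the master weights
with torch.no_grad():
    master_w -= 1e-3 * master_w.grad
    master_w.grad = None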

Memory Savings

Component       | FP32          | Mixed Precision
Model weights   | 4 bytes/param | 2 bytes/param (+ 4 bytes/param FP32 master copy)
Gradients       | 4 bytes/param | 2 bytes/param
Activations     | 4 bytes/value | 2 bytes/value
7B model total  | ~56 GB        | ~35 GB
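
The halving of per-element storage is easy to confirm directly; this snippet simply compares the raw byte size of the same tensor in FP32 and BF16.

import torch

# The same tensor occupies half the bytes in a 16-bit format
x32 = torch.empty(1_000_000, dtype=torch.float32)
x16 = torch.empty(1_000_000, dtype=torch.bfloat16)
print(x32.element_size() * x32.nelement())  # 4000000 bytes
print(x16.element_size() * x16.nelement())  # 2000000 bytes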

Speed Improvements

Modern GPUs have dedicated hardware for lower precision:

GPU      | FP32 TFLOPS | FP16 TFLOPS | BF16 TFLOPS
RTX 3090 | 36          | 71          | 71
RTX 4090 | 83          | 165         | 165
A100     | 19.5        | 312         | 312
H100     | 67          | 1,979       | 1,979
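
You can approximate these peak numbers on your own hardware with a rough matmul benchmark. The sketch below times square matrix multiplications and converts to TFLOPS; it ignores warm-up and clock variation, and the matrix size is arbitrary, so treat the output as a ballpark figure.

import time
import torch

def matmul_tflops(dtype, n=4096, iters=20):
    """Rough TFLOPS estimate from timed n x n matrix multiplications."""
    a = torch.randn(n, n, device="cuda", dtype=dtype)
    b = torch.randn(n, n, device="cuda", dtype=dtype)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()
    return 2 * n**3 * iters / (time.time() - start) / 1e12

print("FP32:", matmul_tflops(torch.float32))
print("FP16:", matmul_tflops(torch.float16))
if torch.cuda.is_bf16_supported():
    print("BF16:", matmul_tflops(torch.bfloat16))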

FP16 Training with Loss Scaling

FP16's limited range can cause gradient underflow. Loss scaling prevents this:

from torch.cuda.amp import autocast, GradScaler

# Initialize scaler
scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16
    with autocast(dtype=torch.float16):
        outputs = model(batch['input_ids'])
        loss = criterion(outputs, batch['labels'])

    # Scale loss to prevent gradient underflow
    scaler.scale(loss).backward()

    # Unscale gradients and clip
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Optimizer step (skips if inf/nan gradients)
    scaler.step(optimizer)
    scaler.update()

How Loss Scaling Works

Without scaling:
  Gradient = 0.00000001 (1e-8) → FP16: 0 (underflow!)

With scaling (scale = 65,536):
  Scaled gradient = 1e-8 × 65,536 ≈ 0.00066 → representable in FP16 ✓
  After unscaling in FP32 = 0.00066 / 65,536 ≈ 1e-8 ✓

The GradScaler dynamically adjusts the scale factor:

  • Increases scale when gradients are healthy
  • Decreases scale when overflow (inf/nan) detected
  • Skips optimizer step on overflow
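
This behaviour is easy to watch in a tiny self-contained loop: scaler.get_scale() reports the current scale, which grows after a long run of healthy steps and halves whenever an overflow forces a skipped step. The model and data below are placeholders chosen only to make the loop runnable.

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler

# Minimal demo of dynamic loss scaling (assumes a CUDA GPU)
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()  # starts at 2**16 by default

for step in range(5):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pred = model(torch.randn(8, 16, device="cuda"))
        loss = F.mse_loss(pred, torch.zeros_like(pred))
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # skipped automatically if inf/nan gradients were found
    scaler.update()          # grows the scale after healthy steps, halves it on overflow
    print(step, scaler.get_scale())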

BF16 is simpler because it doesn't need scaling:

# BF16 training - no scaler needed!
for batch in dataloader:
    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs = model(batch['input_ids'])
        loss = criterion(outputs, batch['labels'])

    loss.backward()
    optimizer.step()

When to Choose BF16 vs FP16

Scenario                              | Recommendation
Ampere+ GPU (A100, RTX 30/40, H100)   | BF16
Older GPU (V100, RTX 20 series)       | FP16 with loss scaling
Numerical stability issues with FP16  | BF16
Maximum mantissa precision needed     | FP16 (more mantissa bits)
Training large language models        | BF16

Hugging Face Integration

TrainingArguments

from transformers import TrainingArguments

# FP16 training
training_args = TrainingArguments(
    output_dir="./output",
    fp16=True,
    fp16_full_eval=True,
)

# BF16 training (recommended for Ampere+)
training_args = TrainingArguments(
    output_dir="./output",
    bf16=True,
    bf16_full_eval=True,
)

Automatic Selection

import torch

def get_mixed_precision_kwargs():
    """Select the best mixed precision flag for the current hardware."""
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return {"bf16": True}
    elif torch.cuda.is_available():
        return {"fp16": True}
    return {}  # CPU fallback: train in full precision

training_args = TrainingArguments(
    output_dir="./output",
    **get_mixed_precision_kwargs(),
)

Operations That Require FP32

Some operations are numerically unstable in lower precision:

# Operations that should stay in FP32
fp32_operations = [
    "softmax",      # Exponentials can overflow
    "layer_norm",   # Needs precision for statistics
    "loss_fn",      # Log operations sensitive
    "optimizer",    # Weight updates need precision
]

import torch.nn.functional as F

# PyTorch autocast handles this automatically
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    # These automatically run in FP32:
    x = F.softmax(logits, dim=-1)
    x = F.layer_norm(x, normalized_shape)
    loss = F.cross_entropy(x, targets)

Custom Autocast Rules

# Force specific operations to FP32 by disabling autocast locally
# (custom_fwd with cast_inputs serves the same purpose, but it is meant for
# the forward() of a custom torch.autograd.Function)
def stable_softmax(x):
    with torch.autocast(device_type='cuda', enabled=False):
        return F.softmax(x.float(), dim=-1)

# Or use autocast context manager
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    x = model(inputs)  # runs in BF16 under autocast

    # Disable autocast for sensitive operations
    with torch.autocast(device_type='cuda', enabled=False):
        x = x.float()
        loss = sensitive_loss_function(x)

Common Issues and Solutions

Issue 1: Loss Goes to NaN

Cause: Gradient overflow or unstable operations

# Solution 1: Use BF16 instead of FP16
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    ...

# Solution 2: Adjust GradScaler settings
scaler = GradScaler(
    init_scale=2**10,       # Lower initial scale (default 2**16)
    growth_interval=4000,   # Grow the scale less often (default 2000)
)

# Solution 3: Gradient clipping before unscaling
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
scaler.step(optimizer)

Issue 2: Model Converges Slower

Cause: Loss of precision in gradients

# Solution: Keep master weights in FP32
# This is automatic with PyTorch AMP, but verify:
for param in model.parameters():
    assert param.dtype == torch.float32  # Master weights

# During forward/backward, copies are made in FP16/BF16

Issue 3: Inference Differs from Training

Cause: Not using autocast during inference

# Match training precision during inference
model.eval()
with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs = model(inputs)

Performance Benchmarks

Training Throughput (samples/second)

LLaMA-7B on A100 80GB:

Precision       | Throughput (samples/s) | Memory | vs FP32
FP32            | 12.4                   | 75 GB  | 1.0x
FP16 + scaling  | 28.6                   | 42 GB  | 2.3x
BF16            | 29.2                   | 42 GB  | 2.4x

Memory Comparison

Model Size | FP32   | Mixed Precision | Savings
1B         | 8 GB   | 4.5 GB          | 44%
7B         | 56 GB  | 32 GB           | 43%
13B        | 104 GB | 60 GB           | 42%
70B        | 560 GB | 320 GB          | 43%

Advanced: FP8 Training

H100 GPUs support FP8 for even more performance:

# FP8 (Hopper GPUs only) - support in PyTorch is still experimental

# Two FP8 formats:
#   E4M3: 4 exponent bits, 3 mantissa bits - typically used for weights/activations in the forward pass
#   E5M2: 5 exponent bits, 2 mantissa bits - typically used for gradients

import torch

# PyTorch 2.1+ exposes the FP8 dtypes; end-to-end FP8 training usually goes
# through a library such as NVIDIA Transformer Engine - check the latest docs
print(torch.float8_e4m3fn, torch.float8_e5m2)

Complete Training Example

import torch
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW

# Setup
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float32,  # Master weights in FP32
).to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)

# Choose precision based on hardware
use_bf16 = torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16
scaler = GradScaler(enabled=not use_bf16)  # Only for FP16

# Training loop (dataloader and num_epochs are assumed to be defined elsewhere)
model.train()
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()

        # Mixed precision forward pass
        with autocast(dtype=dtype):
            outputs = model(**batch)
            loss = outputs.loss

        # Backward pass
        if use_bf16:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

References

  1. Micikevicius, P., et al. (2018). "Mixed Precision Training." arXiv:1710.03740

  2. NVIDIA. (2025). "Automatic Mixed Precision." NVIDIA Documentation

  3. Kalamkar, D., et al. (2019). "A Study of BFLOAT16 for Deep Learning Training." arXiv:1905.12322

  4. PyTorch. (2025). "Automatic Mixed Precision." PyTorch Documentation
