
Distributed Training for Large Models: DDP, FSDP, and DeepSpeed Explained

Complete guide to distributed training in PyTorch. Learn DDP, FSDP, and DeepSpeed ZeRO with practical examples, memory analysis, and scaling strategies for training models from 1B to 100B+ parameters.

Flash Attention Team · January 8, 2026 · 10 min read
Tags: distributed training, DDP, FSDP, DeepSpeed, ZeRO, multi-GPU, PyTorch

Training large models requires distributing work across multiple GPUs. This comprehensive guide explains the three main approaches—DDP, FSDP, and DeepSpeed—with practical implementation details and performance analysis.

The Distributed Training Landscape

| Method | When to Use | Memory Efficiency | Setup Complexity |
| --- | --- | --- | --- |
| DDP | Model fits on 1 GPU | Low | Simple |
| FSDP | Model doesn't fit on 1 GPU | High | Medium |
| DeepSpeed ZeRO-2 | Large models, medium efficiency | Medium | Medium |
| DeepSpeed ZeRO-3 | Very large models | Very High | Complex |

Distributed Data Parallel (DDP)

How DDP Works

DDP is the simplest distributed strategy: replicate the full model on each GPU and split data across GPUs:

GPU 0: Full model copy + Batch 0 → Gradients 0
GPU 1: Full model copy + Batch 1 → Gradients 1
GPU 2: Full model copy + Batch 2 → Gradients 2
GPU 3: Full model copy + Batch 3 → Gradients 3
           ↓ All-Reduce gradients
All GPUs: Averaged gradients → Update weights

DDP Implementation

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    # Create model and move it to this process's GPU
    model = MyModel().to(rank)
    model = DDP(model, device_ids=[rank])

    # Distributed sampler gives each rank a disjoint slice of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=32)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Shuffle differently each epoch

        for batch in dataloader:
            optimizer.zero_grad()
            loss = model(batch.to(rank))
            loss.backward()   # DDP all-reduces gradients during backward
            optimizer.step()

    cleanup()

# Launch one process per GPU
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size)

DDP with torchrun

# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py

# Multi-node (node 0)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr=10.0.0.1 --master_port=29500 train.py

# Multi-node (node 1)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
    --master_addr=10.0.0.1 --master_port=29500 train.py
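
When launching with torchrun, the script usually reads its rank from environment variables instead of using mp.spawn. A minimal sketch of that setup, assuming the same training loop as above (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are all set by torchrun):

import os
import torch
import torch.distributed as dist

def setup_from_env():
    # torchrun sets the rendezvous environment variables, so no rank/world_size args are needed
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank, dist.get_world_size()

if __name__ == "__main__":
    local_rank, world_size = setup_from_env()
    # ... build model, wrap in DDP(model, device_ids=[local_rank]), run the training loop ...
    dist.destroy_process_group()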

DDP Memory Analysis

Per-GPU Memory = Model + Optimizer States + Gradients + Activations

For 7B model with AdamW:
- Model (FP16): 14 GB
- Optimizer states (FP32): 56 GB (Adam momentum + variance, 4 bytes each per parameter)
- Gradients (FP16): 14 GB
- Activations: ~10-20 GB
Total: ~95-105 GB per GPU

DDP requires the entire model to fit on each GPU—no memory savings from distribution.
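
As a back-of-the-envelope check, the breakdown above can be reproduced in a few lines (a sketch using the same assumptions; activation memory is passed in because it depends on batch size and sequence length):

def estimate_ddp_memory_gb(params_billions, activations_gb=15.0):
    model_gb = params_billions * 2      # BF16/FP16 weights: 2 bytes per parameter
    optimizer_gb = params_billions * 8  # Adam momentum + variance in FP32: 8 bytes per parameter
    grads_gb = params_billions * 2      # BF16/FP16 gradients: 2 bytes per parameter
    return model_gb + optimizer_gb + grads_gb + activations_gb

print(estimate_ddp_memory_gb(7))  # ~99 GB for a 7B model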

Fully Sharded Data Parallel (FSDP)

How FSDP Works

FSDP shards model parameters, gradients, and optimizer states across GPUs:

Before forward pass:
GPU 0: Shard 0 of model
GPU 1: Shard 1 of model
GPU 2: Shard 2 of model
GPU 3: Shard 3 of model

During forward pass (per layer):
1. All-Gather: Collect full layer weights from all GPUs
2. Compute: Forward pass with full weights
3. Discard: Free gathered weights (keep only local shard)

During backward pass (per layer):
1. All-Gather: Collect full layer weights
2. Compute: Backward pass
3. Reduce-Scatter: Each GPU gets gradient shard
4. Update: Each GPU updates its parameter shard

FSDP Implementation

import functools
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def train_fsdp(rank, world_size):
    setup(rank, world_size)

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.bfloat16,
    )

    # FSDP wrapping policy (wrap each transformer layer)
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    # Mixed precision config
    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    # Wrap with FSDP
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=rank,
        limit_all_gathers=True,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()

FSDP Sharding Strategies

| Strategy | Memory Savings | Communication | Use Case |
| --- | --- | --- | --- |
| FULL_SHARD | Best | Highest | Model >> GPU memory |
| SHARD_GRAD_OP | Medium | Medium | Balanced |
| NO_SHARD | Lowest | Lowest | Model fits (like DDP) |
| HYBRID_SHARD | Configurable | Configurable | Multi-node |
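
Switching strategies is a one-line change on the FSDP wrapper. A minimal sketch, assuming the same model, auto_wrap_policy, mixed_precision, and rank defined in the implementation above:

# SHARD_GRAD_OP keeps full parameters on every GPU but shards
# gradients and optimizer states (roughly ZeRO-2 behavior)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mixed_precision,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    device_id=rank,
)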

FSDP Memory Analysis

With FULL_SHARD and 4 GPUs:

Per-GPU Memory = Model/4 + OptState/4 + Gradients/4 + Activations + Temp

For 7B model:
- Sharded model (FP16): 14/4 = 3.5 GB
- Sharded optimizer: 56/4 = 14 GB
- Sharded gradients: 14/4 = 3.5 GB
- Full layer weights (temp): ~1 GB
- Activations: ~10-20 GB
Total: ~32-42 GB per GPU (vs ~95 GB for DDP)

FSDP with Hugging Face Trainer

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./output",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        # Wrap each LlamaDecoderLayer; the sharding strategy comes from fsdp="full_shard".
        # Note: min_num_params and transformer_layer_cls_to_wrap are mutually exclusive.
        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "fsdp_offload_params": False,
    },
    bf16=True,
    per_device_train_batch_size=4,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

DeepSpeed ZeRO

ZeRO Optimization Stages

DeepSpeed's ZeRO (Zero Redundancy Optimizer) has three stages:

| Stage | Shards | Memory Reduction |
| --- | --- | --- |
| ZeRO-1 | Optimizer states | ~4× |
| ZeRO-2 | + Gradients | ~8× |
| ZeRO-3 | + Parameters | ~N× (N = GPUs) |

ZeRO Memory Math

For 7B model with 4 GPUs:

| Component | No ZeRO | ZeRO-1 | ZeRO-2 | ZeRO-3 |
| --- | --- | --- | --- | --- |
| Parameters (BF16) | 14 GB | 14 GB | 14 GB | 3.5 GB |
| Gradients (BF16) | 14 GB | 14 GB | 3.5 GB | 3.5 GB |
| Optimizer (FP32) | 56 GB | 14 GB | 14 GB | 14 GB |
| Total | 84 GB | 42 GB | 31.5 GB | 21 GB |
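
The table can be reproduced with a short helper (a sketch under the same assumptions: BF16 parameters and gradients at 2 bytes each, FP32 optimizer states at 8 bytes per parameter, activations excluded):

def zero_memory_per_gpu_gb(params_billions, num_gpus, stage):
    params_gb = params_billions * 2     # BF16 parameters
    grads_gb = params_billions * 2      # BF16 gradients
    optim_gb = params_billions * 8      # FP32 Adam momentum + variance

    if stage >= 1:
        optim_gb /= num_gpus            # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads_gb /= num_gpus            # ZeRO-2: also shard gradients
    if stage >= 3:
        params_gb /= num_gpus           # ZeRO-3: also shard parameters

    return params_gb + grads_gb + optim_gb

for stage in (0, 1, 2, 3):
    print(stage, zero_memory_per_gpu_gb(7, 4, stage))  # 84, 42, 31.5, 21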

DeepSpeed Configuration

ds_config.json:
{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",

    "bf16": {
        "enabled": true
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none"
        },
        "offload_param": {
            "device": "none"
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "overlap_comm": true,
        "contiguous_gradients": true
    },

    "gradient_clipping": 1.0
}

DeepSpeed ZeRO-3 Config

{
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}

DeepSpeed with Hugging Face

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./output",
    deepspeed="ds_config.json",
    bf16=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

Launch DeepSpeed

# Single node
deepspeed --num_gpus=4 train.py --deepspeed ds_config.json

# Multi-node with hostfile
deepspeed --hostfile=hostfile train.py --deepspeed ds_config.json

# hostfile format:
# node1 slots=4
# node2 slots=4

Communication Patterns

Understanding Collectives

| Collective | Operation | Use in Training |
| --- | --- | --- |
| All-Reduce | Sum across GPUs, result to all | DDP gradient sync |
| All-Gather | Collect from all GPUs to all GPUs | FSDP weight gathering |
| Reduce-Scatter | Sum and distribute shards | FSDP gradient sync |
| Broadcast | One GPU to all | Weight initialization |
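
A minimal sketch of these collectives with torch.distributed, assuming an already initialized NCCL process group on a single node (e.g. launched via torchrun, so rank doubles as the local GPU index):

import torch
import torch.distributed as dist

def collective_demo():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cuda", rank)

    # All-Reduce: every rank ends up with the sum of all ranks' tensors
    x = torch.ones(4, device=device) * (rank + 1)
    dist.all_reduce(x, op=dist.ReduceOp.SUM)

    # All-Gather: every rank collects one shard from each rank
    shard = torch.full((4,), float(rank), device=device)
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(gathered, shard)

    # Reduce-Scatter: sum across ranks, each rank keeps one shard of the result
    inputs = torch.arange(4 * world_size, dtype=torch.float32, device=device)
    out = torch.empty(4, device=device)
    dist.reduce_scatter_tensor(out, inputs)

    # Broadcast: copy rank 0's tensor to every rank
    w = torch.randn(4, device=device) if rank == 0 else torch.empty(4, device=device)
    dist.broadcast(w, src=0)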

Communication Overhead

# Approximate communication volume per GPU per optimizer step

# DDP: ring all-reduce on gradients
comm_volume_ddp = 2 * (world_size - 1) / world_size * model_size

# FSDP FULL_SHARD: all-gather weights in forward, all-gather again in
# backward, then reduce-scatter gradients
comm_volume_fsdp = 3 * (world_size - 1) / world_size * model_size

# For a 7B model (14 GB in BF16) on 4 GPUs:
# DDP:  2 * 0.75 * 14 GB ≈ 21 GB per GPU
# FSDP: 3 * 0.75 * 14 GB ≈ 31.5 GB per GPU (~1.5x DDP, traded for memory savings)

Choosing the Right Strategy

Decision Tree

Model fits on single GPU?
├── Yes → Use DDP (simplest, fastest)
└── No
    ├── Training 7B-13B model?
    │   └── FSDP or ZeRO-2
    ├── Training 30B-70B model?
    │   └── FSDP or ZeRO-3
    └── Training 100B+ model?
        └── ZeRO-3 with offloading + tensor parallelism
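
The decision tree can be encoded as a rough heuristic. This is a sketch only; the thresholds are illustrative and should be checked against your actual GPU memory and interconnect:

def choose_strategy(params_billions, fits_on_one_gpu):
    # Mirrors the decision tree above; returns a suggested starting point
    if fits_on_one_gpu:
        return "DDP"
    if params_billions <= 13:
        return "FSDP or DeepSpeed ZeRO-2"
    if params_billions <= 70:
        return "FSDP or DeepSpeed ZeRO-3"
    return "ZeRO-3 with offloading + tensor parallelism"

print(choose_strategy(13, fits_on_one_gpu=False))  # FSDP or DeepSpeed ZeRO-2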

Performance Comparison (8× A100 80GB)

| Model | Method | Throughput | Memory/GPU |
| --- | --- | --- | --- |
| 7B | DDP | 12,000 tok/s | 72 GB |
| 7B | FSDP | 10,500 tok/s | 35 GB |
| 7B | ZeRO-2 | 11,000 tok/s | 38 GB |
| 13B | DDP | OOM | - |
| 13B | FSDP | 6,200 tok/s | 48 GB |
| 13B | ZeRO-2 | 6,800 tok/s | 52 GB |
| 70B | FSDP | 1,100 tok/s | 72 GB |
| 70B | ZeRO-3 | 1,000 tok/s | 65 GB |

Recommendations

| Model Size | GPUs | Strategy | Notes |
| --- | --- | --- | --- |
| <7B | 1-8 | DDP | Simplest, fastest |
| 7B-13B | 4-8 | FSDP / ZeRO-2 | Good balance |
| 13B-70B | 8-32 | FSDP / ZeRO-3 | May need activation checkpointing |
| 70B+ | 32+ | ZeRO-3 + TP | Consider tensor parallelism |

Advanced: Combining Parallelism

3D Parallelism

For very large models, combine multiple strategies:

3D Parallelism = Data Parallel × Tensor Parallel × Pipeline Parallel

Example for 175B model on 64 GPUs:
- Tensor Parallel: 8 (split attention heads)
- Pipeline Parallel: 2 (split layers)
- Data Parallel: 4 (replicate pipeline)

64 = 8 × 2 × 4
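
In PyTorch this GPU layout can be described with a device mesh. The following is a sketch assuming PyTorch 2.3+ and a 64-process launch (e.g. via torchrun across 8 nodes); each parallelism framework then consumes the mesh dimension it needs:

from torch.distributed.device_mesh import init_device_mesh

# 64 GPUs arranged as data-parallel x pipeline-parallel x tensor-parallel
mesh_3d = init_device_mesh(
    "cuda",
    (4, 2, 8),                        # 4 * 2 * 8 = 64 GPUs
    mesh_dim_names=("dp", "pp", "tp"),
)

dp_group = mesh_3d["dp"].get_group()  # process group for gradient all-reduce
tp_group = mesh_3d["tp"].get_group()  # process group for tensor-parallel collectives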

Tensor Parallelism

Split individual layers across GPUs:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

# Column-parallel linear: split the output features across GPUs
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        self.local_out = out_features // world_size
        self.weight = nn.Parameter(torch.randn(self.local_out, in_features))

    def forward(self, x):
        # Each GPU computes its slice of the output features
        local_output = F.linear(x, self.weight)
        # Gather slices from all GPUs to form the full output
        # (real tensor-parallel code uses an autograd-aware gather)
        gathered = [torch.empty_like(local_output) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, local_output)
        return torch.cat(gathered, dim=-1)
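
The attention example below also relies on a row-parallel linear, which is left undefined in the original snippet. A hedged sketch of the usual counterpart, reusing the imports from the column-parallel sketch above (split the input features, then all-reduce the partial outputs):

# Row-parallel linear: split the input features across GPUs
class RowParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        self.local_in = in_features // world_size
        self.weight = nn.Parameter(torch.randn(out_features, self.local_in))

    def forward(self, x_shard):
        # x_shard holds this GPU's slice of the input features
        partial = F.linear(x_shard, self.weight)
        # Sum partial results across GPUs to recover the full output
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)
        return partial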

Megatron-LM Style Parallelism

# Attention with tensor parallelism (Megatron-LM style)
class TensorParallelAttention(nn.Module):
    def __init__(self, config, tp_size):
        super().__init__()
        # Each GPU owns a subset of the attention heads
        self.num_heads_per_partition = config.num_heads // tp_size

        # Column-parallel projection for Q, K, V
        # (Megatron keeps these outputs local to each GPU rather than gathering)
        self.qkv = ColumnParallelLinear(
            config.hidden_size,
            3 * config.hidden_size,
            tp_size,
        )

        # Row-parallel output projection; the all-reduce restores the full output
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            tp_size,
        )

Troubleshooting

Common Issues

1. NCCL Timeout

# Increase the collective timeout for slow networks when creating the process group:
#   dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=60))

# Debug NCCL issues
export NCCL_DEBUG=INFO

2. OOM During All-Gather

# Use limit_all_gathers in FSDP
model = FSDP(model, limit_all_gathers=True)

# Or use smaller bucket size in DeepSpeed
"allgather_bucket_size": 2e8  # Reduce from 5e8

3. Slow Multi-Node Training

# Ensure high-bandwidth interconnect
# Check InfiniBand/RoCE status
ibstat

# Use NCCL environment variables
export NCCL_IB_DISABLE=0
export NCCL_NET_GDR_LEVEL=5

References

  1. Rajbhandari, S., et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." arXiv:1910.02054

  2. Zhao, Y., et al. (2023). "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel." arXiv:2304.11277

  3. SimpleFSDP. (2024). "SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile." arXiv:2411.00284

  4. Shoeybi, M., et al. (2020). "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism." arXiv:1909.08053

  5. Microsoft. (2025). "DeepSpeed Documentation." GitHub
