Distributed Training

PyTorch FSDP vs DeepSpeed ZeRO: Which Sharding Strategy Wins?

Head-to-head comparison of PyTorch FSDP and DeepSpeed ZeRO. Covers performance benchmarks, feature differences, and guidance on when to use each for distributed LLM training.

Flash Attention Team · January 8, 2026 · 7 min read
Tags: FSDP, DeepSpeed, ZeRO, distributed training, PyTorch, model sharding

PyTorch's Fully Sharded Data Parallel (FSDP) and DeepSpeed's ZeRO (Zero Redundancy Optimizer) are the two dominant solutions for training models that don't fit on a single GPU. This guide compares them in detail to help you choose the right tool for your training workload.

Feature Comparison

Feature                  | FSDP           | DeepSpeed ZeRO
Native PyTorch           | Yes            | No (separate library)
Sharding Granularity     | Layer-level    | Parameter-level
CPU Offloading           | Yes            | Yes (more flexible)
NVMe Offloading          | No             | Yes (ZeRO-Infinity)
Mixed Precision          | BF16, FP16     | BF16, FP16, FP8
Activation Checkpointing | PyTorch native | Built-in
Communication Overlap    | Yes            | Yes
torch.compile Support    | Yes (2.0+)     | Limited

Architecture Differences

FSDP Design

FSDP shards at the module level (typically transformer layers):

# FSDP wraps entire modules
class FSDP_Model:
    layer_0: FSDP(TransformerLayer)  # Sharded as unit
    layer_1: FSDP(TransformerLayer)
    ...

# During forward:
# 1. All-gather layer_0 weights
# 2. Forward through layer_0
# 3. Free layer_0 gathered weights
# 4. All-gather layer_1 weights
# ...
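In runnable form, module-level sharding means wrapping each layer as its own FSDP unit. A minimal sketch with a hypothetical model layout (assumes the default process group is already initialized, e.g. via torchrun):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Each layer becomes its own FSDP unit: its weights are all-gathered
# just before its forward/backward runs and freed right after.
def shard_by_layer(model: nn.Module) -> nn.Module:
    for i, layer in enumerate(model.layers):  # assumes a .layers ModuleList
        model.layers[i] = FSDP(layer)
    # Outer wrap shards whatever remains (embeddings, final norm, head)
    return FSDP(model)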

DeepSpeed ZeRO Design

ZeRO shards at the parameter level with more granular control:

# ZeRO partitions individual parameters
class ZeRO_Model:
    param_1: Partition 0  # Could be partial weight matrix
    param_2: Partition 1
    ...

# Uses "partitioned parameters" that gather on-demand
# Can prefetch next parameters while computing current
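Because ZeRO-3 operates on individual parameters, it can partition them while the model is being built, so the full model never materializes on any single device. A sketch using DeepSpeed's zero.Init context manager (ds_config and MyModel are placeholders):

import deepspeed

# ZeRO-3 only: parameters are partitioned across ranks at allocation
# time inside this context, rather than after a full-size build.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = MyModel()  # hypothetical model class

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)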

Key Architectural Differences

Aspect            | FSDP              | DeepSpeed ZeRO
Sharding Unit     | nn.Module         | Individual parameters
Memory Management | PyTorch allocator | Custom allocator
Communication     | NCCL via PyTorch  | NCCL direct
Prefetching       | Per-module        | Per-parameter

Performance Benchmarks

Single Node (8× A100 80GB)

Training throughput (tokens/second):

Model     | FSDP   | ZeRO-2 | ZeRO-3
LLaMA-7B  | 10,500 | 11,200 | 9,800
LLaMA-13B | 6,200  | 6,800  | 6,100
LLaMA-30B | 2,800  | 3,100  | 2,900
LLaMA-70B | 1,100  | 1,050  | 1,000

Multi-Node (4 nodes × 8 GPUs)

Model               | FSDP  | ZeRO-2 | ZeRO-3
LLaMA-70B           | 4,200 | 4,500  | 4,100
LLaMA-70B + offload | N/A   | 2,800  | 3,200

Memory Efficiency

Peak memory per GPU (BF16, batch=4):

Model     | FSDP  | ZeRO-2 | ZeRO-3
LLaMA-7B  | 35 GB | 38 GB  | 28 GB
LLaMA-13B | 48 GB | 52 GB  | 38 GB
LLaMA-70B | 72 GB | 74 GB  | 62 GB
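These figures track the ZeRO paper's memory model (Rajbhandari et al., 2020): with mixed-precision Adam, model states cost roughly 2 + 2 + 12 bytes per parameter (params, grads, FP32 optimizer states), and full sharding divides that by the GPU count. A back-of-envelope check for LLaMA-7B on 8 GPUs (activations, buffers, and fragmentation account for the rest of the measured peak):

# Rough model-state memory per GPU under full sharding (ZeRO-3 / FULL_SHARD)
params = 7e9                    # LLaMA-7B parameter count
gpus = 8
bytes_per_param = 2 + 2 + 12    # BF16 params + BF16 grads + FP32 Adam states

total_gb = params * bytes_per_param / 1e9   # ~112 GB across the node
per_gpu_gb = total_gb / gpus                # ~14 GB per GPU for model states
print(f"{per_gpu_gb:.0f} GB")               # activations at batch=4 fill the gap to ~28 GB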

Configuration Examples

FSDP Configuration

import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Define wrapping policy
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# Mixed precision
mixed_precision = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Wrap model
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mixed_precision,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=False),
    device_id=torch.cuda.current_device(),
    limit_all_gathers=True,
    use_orig_params=True,  # Required for torch.compile
)
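The FSDP constructor assumes the default process group is already initialized. A typical preamble when launching with torchrun, which sets LOCAL_RANK for each worker:

import os

import torch
import torch.distributed as dist

# torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK per worker
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))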

DeepSpeed ZeRO-2 Configuration

{
    "bf16": {"enabled": true},

    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "overlap_comm": true,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": 1.0,

    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}

DeepSpeed ZeRO-3 Configuration

{
    "bf16": {"enabled": true},

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "none"},
        "offload_param": {"device": "none"},
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto"
    }
}
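The "auto" placeholders are resolved by the Hugging Face Trainer integration; when using the config outside the Trainer (e.g. with deepspeed.initialize directly), substitute concrete values. A sketch with illustrative, untuned numbers:

import json

with open("ds_config.json") as f:
    ds_config = json.load(f)

# Replace "auto" entries with explicit values before standalone use
ds_config["zero_optimization"].update({
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e7,
    "stage3_param_persistence_threshold": 1e5,
})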

Use Case Recommendations

When to Use FSDP

  1. Native PyTorch Integration

    • Using torch.compile for additional speedups
    • Want minimal external dependencies
    • Building on latest PyTorch features
  2. Simple Distributed Training

    • Single-node multi-GPU training
    • Standard transformer architectures
    • Don't need advanced offloading
  3. Ecosystem Compatibility

    • PyTorch Lightning integration
    • torchrun for launching
    • PyTorch profiler
# FSDP + torch.compile example
model = FSDP(model, use_orig_params=True)
model = torch.compile(model)  # Works with FSDP

When to Use DeepSpeed

  1. Memory-Constrained Training

    • Need CPU/NVMe offloading
    • Training very large models (100B+)
    • Limited GPU memory
  2. Multi-Node Scale

    • Training across many nodes
    • Need optimized communication
    • Complex parallelism strategies
  3. Advanced Features

    • ZeRO-Infinity for trillion-parameter models
    • Sparse attention support
    • Custom optimizers (1-bit Adam)
# DeepSpeed with CPU offloading
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True}
    }
}

Integration with Hugging Face

FSDP with Trainer

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "fsdp_backward_prefetch": "backward_pre",
        "fsdp_forward_prefetch": True,
        "fsdp_use_orig_params": True,
    },
    bf16=True,
)
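FSDP runs under the standard PyTorch launcher:

# Launch with
# torchrun --nproc_per_node=8 train.py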

DeepSpeed with Trainer

training_args = TrainingArguments(
    output_dir="./output",
    deepspeed="ds_config.json",
    bf16=True,
)

# Launch with
# deepspeed --num_gpus=8 train.py

Accelerate Integration

from accelerate import Accelerator

# FSDP
accelerator = Accelerator(
    mixed_precision="bf16",
    fsdp_plugin=fsdp_plugin,
)

# DeepSpeed
accelerator = Accelerator(
    mixed_precision="bf16",
    deepspeed_plugin=deepspeed_plugin,
)
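The fsdp_plugin and deepspeed_plugin objects must be constructed first. A minimal sketch using Accelerate's plugin classes (defaults shown; customize per workload):

from accelerate import DeepSpeedPlugin, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin()  # defaults to FULL_SHARD
deepspeed_plugin = DeepSpeedPlugin(hf_ds_config="ds_config.json")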

Common Migration Scenarios

DDP to FSDP

# Before (DDP)
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[rank])

# After (FSDP)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
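One gotcha when migrating: construct the optimizer after wrapping, because FSDP replaces the module's parameters with flattened shards:

# The optimizer must see FSDP's flattened parameters, so create it post-wrap
model = FSDP(model, auto_wrap_policy=auto_wrap_policy)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)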

FSDP to DeepSpeed

# Before (FSDP)
model = FSDP(model, ...)
optimizer = torch.optim.AdamW(model.parameters())

# After (DeepSpeed): pass the raw model, not the FSDP-wrapped one
import deepspeed

model, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config=ds_config,
)

Troubleshooting

FSDP Common Issues

# Issue: OOM during all-gather
# Solution: Enable limit_all_gathers to rate-limit prefetched all-gathers
model = FSDP(model, limit_all_gathers=True)

# Issue: Slow with many small modules
# Solution: Increase min_num_params so each FSDP unit is larger
import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=int(1e7),  # Only wrap modules with >10M params
)

# Issue: Checkpoint loading fails
# Solution: Use FSDP's state dict type to gather a full state dict
from torch.distributed.fsdp import StateDictType

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    state_dict = model.state_dict()

DeepSpeed Common Issues

# Issue: Slow ZeRO-3 training
# Solution: Tune prefetch settings
"stage3_prefetch_bucket_size": 5e7,  # Smaller buckets
"stage3_param_persistence_threshold": 1e5,  # Keep small params

# Issue: CPU offload OOM
# Solution: Pin memory and tune buffer size
"offload_optimizer": {
    "device": "cpu",
    "pin_memory": True,
    "buffer_count": 4
}
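A related ZeRO-3 pain point is consolidating partitioned checkpoints into a single state dict. DeepSpeed writes a zero_to_fp32.py helper into each checkpoint directory for this (exact CLI varies by version; see the script's --help):

# After engine.save_checkpoint("./ckpt"), consolidate shards offline:
#   python ./ckpt/zero_to_fp32.py ./ckpt pytorch_model.bin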

Decision Matrix

Scenario                   | Recommendation
7B model, single node      | FSDP (simpler)
70B model, 8 GPUs          | Either works well
70B model, need offloading | DeepSpeed ZeRO-3
Want torch.compile         | FSDP
Multi-node, 100B+          | DeepSpeed ZeRO-3
PyTorch ecosystem priority | FSDP
Maximum memory efficiency  | DeepSpeed ZeRO-3

References

  1. Zhao, Y., et al. (2023). "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel." arXiv:2304.11277

  2. Zhang, R., et al. (2024). "SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile." arXiv:2411.00284

  3. Rajbhandari, S., et al. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models." arXiv:1910.02054

  4. Rajbhandari, S., et al. (2021). "ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning." arXiv:2104.07857
