PyTorch FSDP (Fully Sharded Data Parallel) and DeepSpeed ZeRO are the two dominant sharded data-parallel solutions for training models that don't fit on a single GPU. This guide provides a detailed comparison to help you choose the right tool for your training workload.
Feature Comparison
| Feature | FSDP | DeepSpeed ZeRO |
|---|---|---|
| Native PyTorch | Yes | No (separate library) |
| Sharding Granularity | Layer-level | Parameter-level |
| CPU Offloading | Yes | Yes (more flexible) |
| NVMe Offloading | No | Yes (ZeRO-Infinity) |
| Mixed Precision | BF16, FP16 | BF16, FP16, FP8 |
| Activation Checkpointing | PyTorch native | Built-in |
| Communication Overlap | Yes | Yes |
| torch.compile Support | Yes (2.0+) | Limited |
Architecture Differences
FSDP Design
FSDP shards at the module level (typically transformer layers):
# FSDP wraps entire modules (pseudocode)
class FSDP_Model:
    layer_0: FSDP(TransformerLayer)  # sharded as a unit
    layer_1: FSDP(TransformerLayer)
    ...

# During forward:
# 1. All-gather layer_0 weights
# 2. Forward through layer_0
# 3. Free layer_0's gathered weights
# 4. All-gather layer_1 weights
# ...
DeepSpeed ZeRO Design
ZeRO shards at the parameter level with more granular control:
# ZeRO partitions individual parameters (pseudocode)
class ZeRO_Model:
    param_1: Partition 0  # could be a slice of a weight matrix
    param_2: Partition 1
    ...

# Uses "partitioned parameters" that are gathered on demand
# Can prefetch the next parameters while computing with the current ones
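DeepSpeed exposes this on-demand gathering directly. A minimal sketch, assuming a ZeRO-3 engine has already been created via deepspeed.initialize and using a hypothetical module path inside it:

import deepspeed

# `engine` is assumed to come from deepspeed.initialize(...) with ZeRO stage 3,
# so each rank holds only a partition of every parameter.
layer = engine.module.model.layers[0].mlp.down_proj  # hypothetical module path

print(layer.weight.shape)  # typically an empty placeholder; the data lives in the partitions
with deepspeed.zero.GatheredParameters(layer.weight):
    # Inside the context the full parameter is all-gathered on every rank
    print(layer.weight.shape)  # full [out_features, in_features] shape
# On exit the gathered copy is freed; only the local partition remains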
Key Architectural Differences
| Aspect | FSDP | DeepSpeed ZeRO |
|---|---|---|
| Sharding Unit | nn.Module | Individual parameters |
| Memory Management | PyTorch allocator | Custom allocator |
| Communication | NCCL via PyTorch | NCCL direct |
| Prefetching | Per-module | Per-parameter |
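Despite the different sharding units, FSDP's sharding strategies map roughly onto the ZeRO stages. A quick reference (the mapping is approximate, not an exact equivalence):

from torch.distributed.fsdp import ShardingStrategy

# Rough correspondence between FSDP strategies and ZeRO stages
FSDP_TO_ZERO = {
    ShardingStrategy.NO_SHARD: "no sharding - equivalent to plain DDP",
    ShardingStrategy.SHARD_GRAD_OP: "ZeRO-2: shard gradients + optimizer state",
    ShardingStrategy.FULL_SHARD: "ZeRO-3: also shard parameters",
    ShardingStrategy.HYBRID_SHARD: "FULL_SHARD within a node, replication across nodes",
}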
Performance Benchmarks
Single Node (8× A100 80GB)
Training throughput (tokens/second):
| Model | FSDP | ZeRO-2 | ZeRO-3 |
|---|---|---|---|
| LLaMA-7B | 10,500 | 11,200 | 9,800 |
| LLaMA-13B | 6,200 | 6,800 | 6,100 |
| LLaMA-30B | 2,800 | 3,100 | 2,900 |
| LLaMA-70B | 1,100 | 1,050 | 1,000 |
Multi-Node (4 nodes × 8 GPUs)
Training throughput (tokens/second):
| Model | FSDP | ZeRO-2 | ZeRO-3 |
|---|---|---|---|
| LLaMA-70B | 4,200 | 4,500 | 4,100 |
| LLaMA-70B + offload | N/A | 2,800 | 3,200 |
Memory Efficiency
Peak memory per GPU (BF16, batch=4):
| Model | FSDP | ZeRO-2 | ZeRO-3 |
|---|---|---|---|
| LLaMA-7B | 35 GB | 38 GB | 28 GB |
| LLaMA-13B | 48 GB | 52 GB | 38 GB |
| LLaMA-70B | 72 GB | 74 GB | 62 GB |
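These numbers follow the usual mixed-precision accounting: with Adam, each parameter costs roughly 2 bytes (bf16 weight) + 2 bytes (bf16 gradient) + 12 bytes (fp32 master weight plus the two Adam moments), and each ZeRO stage divides more of that 16-byte budget across the W data-parallel ranks. A back-of-the-envelope sketch for model states only (activations, buffers, and allocator fragmentation come on top, which is why the table above is higher):

def model_state_gb(n_params: float, world_size: int, stage: int) -> float:
    """Approximate per-GPU model-state memory (GB) for mixed-precision Adam.

    Per parameter: 2 B weight + 2 B gradient + 12 B optimizer state.
    Stage 1 shards optimizer state, stage 2 also shards gradients,
    stage 3 (and FSDP FULL_SHARD) also shards the weights.
    """
    weights = 2 * n_params if stage < 3 else 2 * n_params / world_size
    grads = 2 * n_params if stage < 2 else 2 * n_params / world_size
    optim = 12 * n_params / world_size if stage >= 1 else 12 * n_params
    return (weights + grads + optim) / 1e9

# LLaMA-7B (~6.7B params) on 8 GPUs:
print(model_state_gb(6.7e9, 8, stage=2))  # ~25 GB of model states
print(model_state_gb(6.7e9, 8, stage=3))  # ~13 GB of model states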
Configuration Examples
FSDP Configuration
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Define the wrapping policy: each decoder layer becomes its own FSDP unit
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# Mixed precision
mixed_precision = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Wrap model
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mixed_precision,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=False),
    device_id=torch.cuda.current_device(),
    limit_all_gathers=True,
    use_orig_params=True,  # required for torch.compile
)
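Once wrapped, the model is used like any other nn.Module. A minimal training-step sketch, assuming the process group is already initialized and the dataloader yields dicts of input_ids/labels tensors:

# Build the optimizer AFTER wrapping so it references the sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in dataloader:
    batch = {k: v.cuda() for k, v in batch.items()}
    loss = model(**batch).loss   # each layer is all-gathered on the fly
    loss.backward()              # gradients are reduce-scattered per layer
    optimizer.step()
    optimizer.zero_grad()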
DeepSpeed ZeRO-2 Configuration
{
  "bf16": {"enabled": true},
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": 1.0,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}
DeepSpeed ZeRO-3 Configuration
{
  "bf16": {"enabled": true},
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {"device": "none"},
    "offload_param": {"device": "none"},
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto"
  }
}
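The "auto" values in both configs are resolved by the Hugging Face Trainer integration; when calling DeepSpeed directly they must be replaced with concrete numbers. A standalone sketch with hypothetical values, assuming `model` is an already-constructed nn.Module:

import deepspeed
import torch

ds_config = {
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,
    "zero_optimization": {"stage": 3, "overlap_comm": True},
}

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Builds the engine, shards optimizer state, and (for stage 3) partitions the parameters
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config=ds_config,
)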
Use Case Recommendations
When to Use FSDP
- Native PyTorch Integration
  - Using torch.compile for additional speedups
  - Want minimal external dependencies
  - Building on the latest PyTorch features
- Simple Distributed Training
  - Single-node multi-GPU training
  - Standard transformer architectures
  - Don't need advanced offloading
- Ecosystem Compatibility
  - PyTorch Lightning integration
  - torchrun for launching
  - PyTorch profiler
# FSDP + torch.compile example
model = FSDP(model, use_orig_params=True)
model = torch.compile(model) # Works with FSDP
When to Use DeepSpeed
- Memory-Constrained Training
  - Need CPU/NVMe offloading
  - Training very large models (100B+)
  - Limited GPU memory
- Multi-Node Scale
  - Training across many nodes
  - Need optimized communication
  - Complex parallelism strategies
- Advanced Features
  - ZeRO-Infinity for trillion-parameter models
  - Sparse attention support
  - Custom optimizers (1-bit Adam)
# DeepSpeed with CPU offloading
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True}
    }
}
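For ZeRO-Infinity-style NVMe offloading, which FSDP has no counterpart for, the same offload blocks accept "nvme" as the device. A sketch with a hypothetical local SSD path:

# DeepSpeed with NVMe offloading (ZeRO-Infinity)
ds_config_nvme = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "nvme", "nvme_path": "/local_nvme", "pin_memory": True},
        "offload_param": {"device": "nvme", "nvme_path": "/local_nvme", "pin_memory": True}
    }
}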
Integration with Hugging Face
FSDP with Trainer
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "fsdp_backward_prefetch": "backward_pre",
        "fsdp_forward_prefetch": True,
        "fsdp_use_orig_params": True,
    },
    bf16=True,
)
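FSDP jobs launch with the standard PyTorch launcher:

# Launch with
# torchrun --nproc_per_node=8 train.py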
DeepSpeed with Trainer
training_args = TrainingArguments(
    output_dir="./output",
    deepspeed="ds_config.json",
    bf16=True,
)

# Launch with
# deepspeed --num_gpus=8 train.py
Accelerate Integration
from accelerate import Accelerator, DeepSpeedPlugin, FullyShardedDataParallelPlugin

# FSDP (plugin shown with defaults; configure sharding and wrapping as needed)
fsdp_plugin = FullyShardedDataParallelPlugin()
accelerator = Accelerator(
    mixed_precision="bf16",
    fsdp_plugin=fsdp_plugin,
)

# DeepSpeed
deepspeed_plugin = DeepSpeedPlugin(zero_stage=3, gradient_accumulation_steps=1)
accelerator = Accelerator(
    mixed_precision="bf16",
    deepspeed_plugin=deepspeed_plugin,
)
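With either plugin the training loop is the same; accelerator.prepare applies the configured backend. A sketch, assuming model, optimizer, and dataloader are already constructed:

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    loss = model(**batch).loss
    accelerator.backward(loss)   # dispatches to loss.backward() or the DeepSpeed engine
    optimizer.step()
    optimizer.zero_grad()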
Common Migration Scenarios
DDP to FSDP
# Before (DDP)
model = DDP(model, device_ids=[rank])

# After (FSDP)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
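One difference worth remembering during this migration: FSDP flattens and shards parameters at wrap time, so the optimizer should be created from the wrapped model's parameters, after the FSDP call:

# DDP: the optimizer can be built before or after wrapping
# FSDP: build it after wrapping so it sees the sharded (flattened) parameters
model = FSDP(model, auto_wrap_policy=auto_wrap_policy)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)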
FSDP to DeepSpeed
import deepspeed

# Before (FSDP)
model = FSDP(model, ...)
optimizer = torch.optim.AdamW(model.parameters())

# After (DeepSpeed) - initialize from the unwrapped model, not the FSDP-wrapped one
model, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config=ds_config,
)
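The training loop changes as well: the returned engine owns the optimizer and gradient accumulation, so backward and step go through the engine rather than through loss.backward() and optimizer.step(). A sketch:

for batch in dataloader:
    loss = model(**batch).loss   # `model` is the engine returned by deepspeed.initialize
    model.backward(loss)         # engine-managed backward (handles ZeRO communication)
    model.step()                 # optimizer step + gradient zeroing (+ LR scheduler, if given)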
Troubleshooting
FSDP Common Issues
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Issue: OOM during all-gather
# Solution: Enable limit_all_gathers (rate-limits in-flight all-gathers)
model = FSDP(model, limit_all_gathers=True)

# Issue: Slow with many small modules
# Solution: Use a size-based policy and increase min_num_params
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=int(1e7),  # only wrap modules with >10M params
)

# Issue: Checkpoint loading fails
# Solution: Use FSDP's state dict type when saving/loading
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    state_dict = model.state_dict()
DeepSpeed Common Issues
# Issue: Slow ZeRO-3 training
# Solution: Tune prefetch settings in the ZeRO config
"stage3_prefetch_bucket_size": 5e7,          # smaller buckets
"stage3_param_persistence_threshold": 1e5,   # keep small params resident on each GPU

# Issue: CPU offload OOM on the host
# Solution: Pin memory and tune the buffer count
"offload_optimizer": {
    "device": "cpu",
    "pin_memory": true,
    "buffer_count": 4
}
Decision Matrix
| Scenario | Recommendation |
|---|---|
| 7B model, single node | FSDP (simpler) |
| 70B model, 8 GPUs | Either works well |
| 70B model, need offloading | DeepSpeed ZeRO-3 |
| Want torch.compile | FSDP |
| Multi-node, 100B+ | DeepSpeed ZeRO-3 |
| PyTorch ecosystem priority | FSDP |
| Maximum memory efficiency | DeepSpeed ZeRO-3 |