PyTorch 2.0's torch.compile provides significant speedups with minimal code changes. This guide explains how it works and how to use it effectively.
What torch.compile Does
torch.compile transforms PyTorch code into optimized kernels through three stages:
Python code → TorchDynamo → FX graph → AOT Autograd → Inductor → Triton/C++ kernels
- TorchDynamo: Captures the Python-level tensor operations as an FX graph (the sketch below shows one way to inspect it)
- AOT Autograd: Traces the forward graph and generates the matching backward graph ahead of time
- Inductor: Lowers the graphs into optimized Triton (GPU) or C++/OpenMP (CPU) kernels
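For a concrete look at the capture stage, here is a minimal sketch (assuming PyTorch 2.x; print_graph_backend is a made-up name for this illustration) that plugs a custom backend into torch.compile so the captured FX graph is printed instead of being optimized:
import torch
def print_graph_backend(gm, example_inputs):
    # gm is the torch.fx.GraphModule produced by TorchDynamo
    print(gm.graph)       # show the captured ops before any optimization
    return gm.forward     # run the captured graph eagerly, skipping Inductor
@torch.compile(backend=print_graph_backend)
def f(x):
    return torch.relu(x) * 2
f(torch.randn(4))  # the first call triggers capture and prints the graph
Swapping in a throwaway backend like this is also a handy way to confirm that a function is captured as a single graph rather than several fragments.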
Basic Usage
import torch
model = MyModel()
# Compile the model
compiled_model = torch.compile(model)
# Use normally - first call compiles, subsequent calls are fast
output = compiled_model(input)
Compile Modes
| Mode | Compile Time | Runtime | Use Case |
|---|---|---|---|
| default | Medium | Fast | General use |
| reduce-overhead | Medium | Faster | Small batches / overhead-bound workloads (uses CUDA graphs) |
| max-autotune | Slow | Fastest | Production, where the long compile time is amortized |
# Different modes
model = torch.compile(model, mode="default")
model = torch.compile(model, mode="reduce-overhead")
model = torch.compile(model, mode="max-autotune")
Real-World Speedups
Training Benchmarks
| Model | Baseline | torch.compile | Speedup |
|---|---|---|---|
| ResNet-50 | 1.0x | 1.38x | 38% |
| BERT-Base | 1.0x | 1.25x | 25% |
| GPT-2 | 1.0x | 1.31x | 31% |
| LLaMA-7B | 1.0x | 1.15x | 15% |
| ViT-Large | 1.0x | 1.42x | 42% |
Inference Benchmarks
| Model | Baseline | torch.compile | Speedup |
|---|---|---|---|
| ResNet-50 | 1.0x | 1.75x | 75% |
| BERT-Base | 1.0x | 1.52x | 52% |
| GPT-2 | 1.0x | 1.45x | 45% |
How It Works
Graph Capture (TorchDynamo)
TorchDynamo traces Python execution to capture operations:
# Original code
def forward(self, x):
    x = self.linear(x)
    x = F.relu(x)
    return x * 2
# TorchDynamo captures this as a graph:
# linear → relu → mul
Graph Optimization
The captured graph is optimized:
- Kernel fusion: Combine multiple operations into a single kernel (see the demo after this example)
- Memory planning: Reuse buffers instead of allocating fresh memory for every intermediate
- Layout optimization: Choose tensor memory layouts that suit the generated kernels
# Before fusion:
# Kernel 1: linear
# Kernel 2: relu
# Kernel 3: multiply
# After fusion:
# Single fused kernel: linear_relu_mul
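To see this fusion on a real fragment, a small hedged demo (assumes PyTorch 2.x and the documented TORCH_LOGS environment variable; fusion_demo.py is just an illustrative filename):
# Run with: TORCH_LOGS="output_code" python fusion_demo.py
import torch
def f(x, w):
    return torch.relu(x @ w) * 2  # matmul → relu → scale
compiled = torch.compile(f)
compiled(torch.randn(64, 64), torch.randn(64, 64))
# In the logged output, the relu and the scale typically land in one generated
# kernel after the matmul, rather than launching two separate elementwise kernels.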
Code Generation (Inductor)
Inductor generates efficient Triton code:
# Generated Triton kernel (simplified)
@triton.jit
def fused_linear_relu_mul(
    x_ptr, w_ptr, out_ptr,
    M, N, K,
):
    # Efficient tiled matrix multiply + relu + scale
    # Uses shared memory, vectorized loads, etc.
    ...
Best Practices
1. Compile the Right Thing
# GOOD: Compile the entire model
model = torch.compile(model)
# GOOD: Compile specific functions
@torch.compile
def train_step(model, batch):
    loss = model(batch).loss
    loss.backward()
    return loss
# BAD: Don't compile inside loops
for batch in dataloader:
    model = torch.compile(model)  # Re-wraps the model every iteration; wasteful and may recompile!
2. Use fullgraph Mode for Consistency
# Ensure entire model is captured as one graph
model = torch.compile(model, fullgraph=True)
# If this fails, the model contains ops or Python constructs that break the graph
3. Handle Dynamic Shapes
# For varying sequence lengths, use dynamic
model = torch.compile(model, dynamic=True)
# Or specify dynamic dimensions
torch._dynamo.mark_dynamic(input, 0) # Batch dimension is dynamic
torch._dynamo.mark_dynamic(input, 1) # Sequence length is dynamic
4. Warm Up Before Benchmarking
import time
# Compilation happens on the first call
model = torch.compile(model)
# Warm up (triggers compilation)
for _ in range(3):
    model(sample_input)
torch.cuda.synchronize()  # let warm-up and compilation finish before timing
# Now benchmark
start = time.time()
for _ in range(100):
    model(sample_input)
torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
print(f"Time: {time.time() - start:.3f}s")
Common Issues
Issue 1: Graph Breaks
# Operations that cause graph breaks:
def forward(self, x):
    x = self.layer1(x)
    print(x.shape)  # Graph break! Python side effect Dynamo can't trace through
    x = self.layer2(x)
    return x
# Solution: remove the print, or guard it so the break only happens in debug mode
def forward(self, x):
    x = self.layer1(x)
    if self.debug:  # only breaks the graph when debugging is enabled
        print(x.shape)
    x = self.layer2(x)
    return x
Issue 2: Unsupported Operations
# Check where and why graph breaks happen
import torch._dynamo as dynamo
def my_function(x):
    ...
# explain() wraps the function; calling the wrapper reports graph breaks
explanation = dynamo.explain(my_function)(sample_input)
print(explanation)  # lists each break and its reason
Issue 3: Slow Compilation
# During development, stick with the default mode; max-autotune is the slow one
model = torch.compile(model)
# Or turn compilation off entirely while debugging
model = torch.compile(model, disable=True)
Integration with Training
With Hugging Face Trainer
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="./output",
    torch_compile=True,
    torch_compile_mode="reduce-overhead",
)
With Custom Training Loop
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters())
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
With FSDP
# For torch.compile compatibility, FSDP must be constructed with use_orig_params=True
model = FSDP(model, use_orig_params=True)
model = torch.compile(model)
Debugging
Verbose Logging
import torch._dynamo as dynamo
# See what's being compiled
dynamo.config.verbose = True
# See generated code
torch._inductor.config.debug = True
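Recent releases also expose the same information through the documented TORCH_LOGS environment variable, which avoids the private config flags above (the exact set of log names can vary between versions, and train.py stands in for whatever script you run):
# Trace Dynamo, report every graph break, and dump Inductor's generated code
# TORCH_LOGS="+dynamo,graph_breaks,output_code" python train.py
# Report only recompilations and the guards that triggered them
# TORCH_LOGS="recompiles" python train.py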
Examining Generated Code
# Save generated Triton code
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.debug_log = True
model = torch.compile(model)
model(input) # Check ./torch_compile_debug/ for generated code
When Not to Use torch.compile
| Scenario | Recommendation |
|---|---|
| Very dynamic control flow | May cause many recompilations (see the sketch after this table) |
| Quick debugging | Compilation adds overhead |
| Already heavily optimized kernels | Limited additional benefit |
| Custom CUDA kernels | May not be capturable |
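As referenced in the first row above, a hedged sketch of how to surface recompilations, using varying input shapes as a stand-in for dynamic behavior (assumes PyTorch 2.x; recent versions with automatic dynamic shapes typically recompile only once before generalizing):
# Run with TORCH_LOGS="recompiles" to print the guard behind each recompile
import torch
@torch.compile  # torch.compile(f, dynamic=True) avoids most per-shape recompiles
def f(x):
    return torch.nn.functional.relu(x).sum()
for length in (8, 16, 32, 64):
    f(torch.randn(2, length))  # each new length can trigger another compilation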