GPU Optimization

PyTorch 2.0 torch.compile Explained: How to Get Free Speedups

Complete guide to torch.compile in PyTorch 2.0. Learn how it works, when to use it, common pitfalls, and benchmarks showing real-world speedups for training and inference.

Flash Attention Team · January 8, 2026 · 6 min read
Tags: torch.compile, PyTorch 2.0, Triton, Inductor, model optimization

PyTorch 2.0's torch.compile provides significant speedups with minimal code changes. This guide explains how it works and how to use it effectively.

What torch.compile Does

torch.compile transforms PyTorch code into optimized kernels through three stages:

Python code → TorchDynamo → FX Graph → AOT Autograd → Inductor → Triton/C++ kernels
  1. TorchDynamo: Captures Python bytecode execution as an FX graph
  2. AOT Autograd: Traces the backward pass ahead of time so it can be optimized alongside the forward pass
  3. Inductor: Generates optimized Triton kernels for GPU (or C++/OpenMP code for CPU)
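
To make the first stage concrete, here is a minimal sketch that prints the FX graph TorchDynamo captures. The backend name print_graph_backend is our own toy example, not part of the standard workflow; it falls back to eager execution instead of invoking Inductor.

import torch

# Toy backend: print the captured FX graph, then run the graph unoptimized.
def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()   # the operations TorchDynamo recorded
    return gm.forward          # return a callable; here, plain eager execution

@torch.compile(backend=print_graph_backend)
def f(x):
    return torch.relu(x) * 2

f(torch.randn(8))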

Basic Usage

import torch

model = MyModel()

# Compile the model
compiled_model = torch.compile(model)

# Use normally - first call compiles, subsequent calls are fast
output = compiled_model(input)

Compile Modes

Mode            | Compile Time | Runtime | Use Case
default         | Medium       | Fast    | General use
reduce-overhead | Fast         | Faster  | Small batches and inference (uses CUDA graphs to cut launch overhead)
max-autotune    | Slow         | Fastest | Production
# Different modes
model = torch.compile(model, mode="default")
model = torch.compile(model, mode="reduce-overhead")
model = torch.compile(model, mode="max-autotune")
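
The right mode is workload-dependent, so it is worth measuring. A rough sketch for comparing modes, using a helper of our own (time_first_and_steady is not a PyTorch API; assumes a CUDA device):

import time
import torch
import torch._dynamo

def time_first_and_steady(model, x, mode):
    torch._dynamo.reset()                 # clear previously compiled graphs
    compiled = torch.compile(model, mode=mode)
    torch.cuda.synchronize()
    t0 = time.time()
    compiled(x)                           # first call includes compilation
    torch.cuda.synchronize()
    first = time.time() - t0
    t0 = time.time()
    for _ in range(50):                   # steady-state calls reuse compiled kernels
        compiled(x)
    torch.cuda.synchronize()
    return first, (time.time() - t0) / 50

# e.g. time_first_and_steady(model, sample_input, "max-autotune")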

Real-World Speedups

Training Benchmarks

Model     | Baseline | torch.compile | Speedup
ResNet-50 | 1.0x     | 1.38x         | 38%
BERT-Base | 1.0x     | 1.25x         | 25%
GPT-2     | 1.0x     | 1.31x         | 31%
LLaMA-7B  | 1.0x     | 1.15x         | 15%
ViT-Large | 1.0x     | 1.42x         | 42%

Inference Benchmarks

Model     | Baseline | torch.compile | Speedup
ResNet-50 | 1.0x     | 1.75x         | 75%
BERT-Base | 1.0x     | 1.52x         | 52%
GPT-2     | 1.0x     | 1.45x         | 45%

How It Works

Graph Capture (TorchDynamo)

TorchDynamo traces Python execution to capture operations:

# Original code
def forward(self, x):
    x = self.linear(x)
    x = F.relu(x)
    return x * 2

# TorchDynamo captures this as a graph:
# linear → relu → mul

Graph Optimization

The captured graph is optimized:

  • Kernel fusion: Combine multiple ops into one kernel
  • Memory planning: Reuse memory allocations
  • Layout optimization: Choose optimal tensor layouts
# Before fusion:
# Kernel 1: linear
# Kernel 2: relu
# Kernel 3: multiply

# After fusion:
# Single fused kernel: linear_relu_mul
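
As a concrete (simplified) illustration, a chain of pointwise operations like the one below is typically fused by Inductor into a single Triton kernel, so the intermediate tensors never round-trip through global memory. The function name is our own; the snippet assumes a CUDA device.

import torch

def pointwise_chain(x):
    # relu, multiply, add, sigmoid: four eager kernels, typically one after fusion
    return torch.sigmoid(torch.relu(x) * 2.0 + 1.0)

compiled_chain = torch.compile(pointwise_chain)
y = compiled_chain(torch.randn(1024, 1024, device="cuda"))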

Code Generation (Inductor)

Inductor generates efficient Triton code:

# Generated Triton kernel (simplified)
@triton.jit
def fused_linear_relu_mul(
    x_ptr, w_ptr, out_ptr,
    M, N, K,
):
    # Efficient tiled matrix multiply + relu + scale
    # Uses shared memory, vectorized loads, etc.
    ...

Best Practices

1. Compile the Right Thing

# GOOD: Compile the entire model
model = torch.compile(model)

# GOOD: Compile specific functions
@torch.compile
def train_step(model, batch):
    loss = model(batch).loss
    loss.backward()
    return loss

# BAD: Don't compile inside loops
for batch in dataloader:
    model = torch.compile(model)  # Wraps the model again every iteration - wasteful, don't do this

2. Use fullgraph Mode for Consistency

# Ensure entire model is captured as one graph
model = torch.compile(model, fullgraph=True)

# If this fails, you have unsupported operations
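
A minimal sketch of how to use this during development (the exact exception type varies across PyTorch versions, so we catch broadly here; sample_input is any representative batch):

# With fullgraph=True, the first call fails loudly at the operation that
# would otherwise cause a silent graph break.
compiled = torch.compile(model, fullgraph=True)
try:
    compiled(sample_input)
except Exception as err:
    print("Could not capture a single graph:", err)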

3. Handle Dynamic Shapes

# For varying sequence lengths, use dynamic
model = torch.compile(model, dynamic=True)

# Or specify dynamic dimensions
torch._dynamo.mark_dynamic(input, 0)  # Batch dimension is dynamic
torch._dynamo.mark_dynamic(input, 1)  # Sequence length is dynamic
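
For example, with dynamic=True a (hypothetical) token-based model called on different sequence lengths should reuse one compiled graph instead of recompiling per shape:

import torch

compiled = torch.compile(model, dynamic=True)
for seq_len in (128, 256, 512):
    tokens = torch.randint(0, 1000, (8, seq_len), device="cuda")
    compiled(tokens)   # same graph, symbolic sequence dimension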

4. Warm Up Before Benchmarking

import time

# Compilation happens on first call
model = torch.compile(model)

# Warm up (triggers compilation)
for _ in range(3):
    model(sample_input)
torch.cuda.synchronize()  # ensure compilation and warm-up kernels have finished

# Now benchmark
start = time.time()
for _ in range(100):
    model(sample_input)
torch.cuda.synchronize()
print(f"Time: {time.time() - start:.3f}s")

Common Issues

Issue 1: Graph Breaks

# Operations that cause graph breaks:
def forward(self, x):
    x = self.layer1(x)
    print(x.shape)  # Graph break! Python side effect
    x = self.layer2(x)
    return x

# Solution: Remove or guard print statements
def forward(self, x):
    x = self.layer1(x)
    if self.debug:  # Only in debug mode
        print(x.shape)
    x = self.layer2(x)
    return x

Issue 2: Unsupported Operations

# Check for unsupported operations and graph breaks
import torch._dynamo as dynamo

def my_function(x):
    ...

# dynamo.explain wraps the function; calling the wrapper with real inputs
# returns a report of captured graphs and graph-break reasons
# (older 2.0.x releases use the dynamo.explain(my_function, sample_input) form)
explanation = dynamo.explain(my_function)(sample_input)
print(explanation)

Issue 3: Slow Compilation

# For faster iteration during development
model = torch.compile(model, mode="reduce-overhead")

# Or disable during debugging
model = torch.compile(model, disable=True)
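
If repeated compilation across runs is the bottleneck, recent PyTorch releases can also cache compiled FX graphs on disk. Our understanding is that this is controlled by torch._inductor.config.fx_graph_cache; check availability in your version.

import torch._inductor.config as inductor_config

# Reuse compilation results across process restarts (recent PyTorch versions)
inductor_config.fx_graph_cache = True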

Integration with Training

With Hugging Face Trainer

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    torch_compile=True,
    torch_compile_mode="reduce-overhead",
)

With Custom Training Loop

model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters())

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
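
torch.compile also composes with mixed precision; autocast simply wraps the forward pass inside the compiled region as usual. A sketch of the same loop with bf16 autocast (assumes bf16-capable hardware):

model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters())

for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss.backward()
    optimizer.step()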

With FSDP

# FSDP must be wrapped with use_orig_params=True for torch.compile compatibility
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(model, use_orig_params=True)
model = torch.compile(model)

Debugging

Verbose Logging

import torch._dynamo as dynamo

# See what's being compiled
dynamo.config.verbose = True

# See generated code
torch._inductor.config.debug = True
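
On newer releases (2.1+), the structured logging API is usually more targeted. The same information is reachable via the TORCH_LOGS environment variable or from code; exact artifact names can differ by version.

# Shell:  TORCH_LOGS="graph_breaks,recompiles" python train.py
# Or from Python:
import torch._logging

torch._logging.set_logs(graph_breaks=True, recompiles=True)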

Examining Generated Code

# Save generated Triton code
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.debug_log = True

model = torch.compile(model)
model(input)  # Check ./torch_compile_debug/ for generated code

When Not to Use torch.compile

Scenario                          | Reason
Very dynamic control flow         | May cause many recompilations
Quick debugging                   | Compilation adds overhead
Already heavily optimized kernels | Limited additional benefit
Custom CUDA kernels               | May not be capturable

