Attention is the core mechanism that makes transformers work. This guide explains attention from first principles, building up to modern variants used in production LLMs.
The Attention Mechanism
Intuition
Attention answers: "Which parts of the input should I focus on to process each position?"
Input: "The cat sat on the mat"
When processing "sat":
- High attention to "cat" (subject)
- Medium attention to "mat" (object)
- Low attention to "the" (less relevant)
Scaled Dot-Product Attention
The fundamental attention computation:
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: Queries [batch, seq_len, d_k]
    K: Keys    [batch, seq_len, d_k]
    V: Values  [batch, seq_len, d_v]
    """
    d_k = Q.shape[-1]

    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch, seq, seq]

    # Step 2: Scale by sqrt(d_k) for stable gradients
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply mask (for causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 4: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)  # [batch, seq, seq]

    # Step 5: Apply attention to values
    output = torch.matmul(attention_weights, V)  # [batch, seq, d_v]
    return output, attention_weights
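As a quick sanity check, the function can be called on random tensors (the shapes below are made up for illustration); the attention weights for each query position form a distribution that sums to 1:

    Q = torch.randn(2, 6, 64)    # [batch=2, seq_len=6, d_k=64]
    K = torch.randn(2, 6, 64)
    V = torch.randn(2, 6, 128)   # d_v may differ from d_k

    out, weights = scaled_dot_product_attention(Q, K, V)
    print(out.shape)             # torch.Size([2, 6, 128])
    print(weights.sum(dim=-1))   # all ones: each row is a probability distribution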
Why Scale by √d_k?
Without scaling, dot products grow with dimension, pushing softmax into saturated regions:
# Without scaling (d_k=512)
dot_product = sum(q_i * k_i for i in range(512))
# If q, k ~ N(0,1), dot_product ~ N(0, 512)
# Large values → softmax approaches one-hot → vanishing gradients
# With scaling
scaled = dot_product / sqrt(512) # ~ N(0, 1)
# Softmax stays in "useful" range
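A quick numeric illustration of this effect (values are random, so exact numbers will vary): with d_k = 512, unscaled scores have a standard deviation around √512 ≈ 22.6 and the softmax collapses to near one-hot, while the scaled scores stay spread out:

    import math
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    d_k = 512
    q = torch.randn(d_k)
    keys = torch.randn(8, d_k)       # 8 candidate keys

    raw = keys @ q                   # unscaled scores, std ≈ sqrt(d_k)
    scaled = raw / math.sqrt(d_k)    # scaled scores, std ≈ 1

    print(F.softmax(raw, dim=-1))    # nearly one-hot: one weight close to 1
    print(F.softmax(scaled, dim=-1)) # weight spread over several keys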
Query, Key, Value
Creating Q, K, V
import torch.nn as nn

class AttentionProjections(nn.Module):
    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_v)

    def forward(self, x):
        Q = self.W_q(x)  # What am I looking for?
        K = self.W_k(x)  # What do I contain?
        V = self.W_v(x)  # What information do I provide?
        return Q, K, V
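Putting the two pieces together (shapes are illustrative only), the projections feed directly into the attention function defined above:

    proj = AttentionProjections(d_model=512, d_k=64, d_v=64)
    x = torch.randn(2, 10, 512)              # [batch, seq_len, d_model]
    Q, K, V = proj(x)                        # each [2, 10, 64]
    out, weights = scaled_dot_product_attention(Q, K, V)
    print(out.shape, weights.shape)          # [2, 10, 64], [2, 10, 10]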
Intuitive Understanding
| Component | Analogy | Purpose |
|---|---|---|
| Query | Question | "What information do I need?" |
| Key | Index | "What information do I have?" |
| Value | Content | "Here's my actual information" |
Multi-Head Attention
Why Multiple Heads?
Different heads can attend to different aspects:
- Head 1: Syntactic relationships
- Head 2: Semantic similarity
- Head 3: Positional patterns
Implementation
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape

        # Project and reshape to [batch, heads, seq, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(attention, V)

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(out)
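A quick shape check (configuration values here are arbitrary):

    mha = MultiHeadAttention(d_model=512, num_heads=8)
    x = torch.randn(2, 10, 512)
    y = mha(x)
    print(y.shape)   # torch.Size([2, 10, 512]), same shape as the input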
Visualization
Input: [batch, seq, d_model=512]
↓
Split into 8 heads: [batch, 8, seq, 64]
↓
Each head computes attention independently
↓
Concatenate: [batch, seq, 512]
↓
Output projection: [batch, seq, 512]
Self-Attention vs Cross-Attention
Self-Attention
Q, K, V all come from the same sequence:
# Self-attention: the sequence attends to itself
x = input_sequence
Q, K, V = W_q(x), W_k(x), W_v(x)   # all three projections read the same tokens
output = attention(Q, K, V)
Used in: Encoder layers, decoder self-attention
Cross-Attention
Q from one sequence, K, V from another:
# Cross-attention: decoder attends to encoder
Q = W_q(decoder_state)                            # queries come from the decoder
K, V = W_k(encoder_output), W_v(encoder_output)   # keys/values come from the encoder
output = attention(Q, K, V)
Used in: Decoder attending to encoder in seq2seq models
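A minimal single-head cross-attention sketch (names and shapes are illustrative) that reuses the scaled_dot_product_attention function from above; note that the query and key/value sequences may have different lengths:

    class CrossAttention(nn.Module):
        def __init__(self, d_model, d_k):
            super().__init__()
            self.W_q = nn.Linear(d_model, d_k)   # queries from the decoder
            self.W_k = nn.Linear(d_model, d_k)   # keys from the encoder
            self.W_v = nn.Linear(d_model, d_k)   # values from the encoder

        def forward(self, decoder_state, encoder_output):
            Q = self.W_q(decoder_state)          # [batch, tgt_len, d_k]
            K = self.W_k(encoder_output)         # [batch, src_len, d_k]
            V = self.W_v(encoder_output)         # [batch, src_len, d_k]
            out, _ = scaled_dot_product_attention(Q, K, V)
            return out                           # [batch, tgt_len, d_k]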
Causal (Masked) Attention
For autoregressive models, prevent attending to future tokens:
def causal_mask(seq_len):
    """Create a lower-triangular mask: 1 = allowed, 0 = blocked."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask
# Effect:
# Position 0 can attend to: [0]
# Position 1 can attend to: [0, 1]
# Position 2 can attend to: [0, 1, 2]
# ...
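Plugging this mask into the earlier attention function (illustrative shapes) makes the weight matrix lower-triangular, so each position only mixes information from itself and earlier positions:

    x = torch.randn(1, 4, 64)
    mask = causal_mask(4)           # [4, 4], broadcasts over the batch dimension
    out, weights = scaled_dot_product_attention(x, x, x, mask=mask)
    print(weights[0])
    # Upper triangle is ~0 (masked with -1e9); row i is a distribution over positions 0..i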
Modern Attention Variants
Multi-Query Attention (MQA)
Single K, V shared across all heads:
class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)   # full size: one query per head
        self.W_k = nn.Linear(d_model, self.d_k)  # single shared key head!
        self.W_v = nn.Linear(d_model, self.d_k)  # single shared value head!
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x)  # [batch, seq, d_model]
        K = self.W_k(x)  # [batch, seq, d_k]
        V = self.W_v(x)  # [batch, seq, d_k]

        # Broadcast K, V to all heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.unsqueeze(1)  # [batch, 1, seq, d_k]
        V = V.unsqueeze(1)  # [batch, 1, seq, d_k]

        # K, V broadcast across the head dimension during attention
        ...
Benefit: Much smaller KV cache (1/num_heads the size of standard MHA's)
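A back-of-the-envelope illustration of the savings, using a hypothetical 32-head, 32-layer model with d_model = 4096, fp16 activations, and an 8k context (these numbers are for illustration only):

    # Hypothetical config, for illustration only
    n_layers, n_heads, d_head = 32, 32, 128
    seq_len, bytes_per_val = 8192, 2               # fp16

    def kv_cache_bytes(n_kv_heads):
        # 2 tensors (K and V) per layer, each [seq_len, n_kv_heads * d_head]
        return 2 * n_layers * seq_len * n_kv_heads * d_head * bytes_per_val

    print(kv_cache_bytes(n_heads) / 2**30)   # MHA: 4.0 GiB
    print(kv_cache_bytes(1) / 2**30)         # MQA: 0.125 GiB (32x smaller)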
Grouped-Query Attention (GQA)
Groups of heads share K, V:
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.heads_per_group = num_heads // num_kv_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
Used in: LLaMA 2 70B (8 KV heads for 64 Q heads)
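A hedged sketch of the corresponding forward pass (a method that would sit inside GroupedQueryAttention above; one common implementation choice is to repeat each KV head across its group so the rest of the computation looks like standard MHA):

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)

        # Each group of query heads shares one KV head
        K = K.repeat_interleave(self.heads_per_group, dim=1)  # [batch, num_heads, seq, d_k]
        V = V.repeat_interleave(self.heads_per_group, dim=1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(attention, V)
        # Output projection (W_o) omitted for brevity
        return out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)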
Sliding Window Attention
Limit attention to local context (Mistral):
def sliding_window_mask(seq_len, window_size):
    """Causal mask restricted to each position and its previous `window_size` positions."""
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size)
        mask[i, start:i + 1] = 1
    return mask
Benefit: O(n × window) instead of O(n²)
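A quick look at the pattern for a short sequence (window_size = 2 here, purely for illustration):

    print(sliding_window_mask(5, window_size=2))
    # tensor([[1., 0., 0., 0., 0.],
    #         [1., 1., 0., 0., 0.],
    #         [1., 1., 1., 0., 0.],
    #         [0., 1., 1., 1., 0.],
    #         [0., 0., 1., 1., 1.]])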
Attention Complexity
| Variant | Time | Memory | KV Cache |
|---|---|---|---|
| Standard MHA | O(n²d) | O(n²) | O(n × h × d) |
| Flash Attention | O(n²d) | O(n) | O(n × h × d) |
| MQA | O(n²d) | O(n²) | O(n × d) |
| GQA | O(n²d) | O(n²) | O(n × g × d) |
| Sliding Window | O(n × w × d) | O(n × w) | O(w × h × d) |

Here n = sequence length, d = head dimension, h = number of attention heads, g = number of KV heads, w = window size; the KV cache column is per layer.
PyTorch Native Implementation
# PyTorch 2.0+ scaled_dot_product_attention
import torch.nn.functional as F

# query, key, value: e.g. [batch, num_heads, seq_len, head_dim]
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,   # pass an explicit mask OR set is_causal=True, not both
    dropout_p=0.1,
    is_causal=True,   # efficient built-in causal masking
)
# Automatically dispatches to Flash Attention / memory-efficient kernels when available
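For reference, the forward pass of the MultiHeadAttention module shown earlier could be rewritten on top of this primitive (a sketch; it assumes the attributes defined in that class and the 1/0 mask convention used throughout this guide):

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Convert the 1/0 mask to boolean (True = attend); scaling and softmax happen inside
        attn_mask = mask.bool() if mask is not None else None
        out = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(out)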
References
- Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS 2017.
- Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150.
- Ainslie, J., et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." arXiv:2305.13245.
- "Cost-Optimal Grouped-Query Attention for Long-Context Modeling." (2025). EMNLP 2025.