Deep Learning

Understanding Transformer Attention: Scaled Dot-Product to Multi-Head

Complete guide to transformer attention mechanisms. Learn scaled dot-product attention, multi-head attention, and modern variants like MQA and GQA with visual explanations and PyTorch code.

Flash Attention Team · January 8, 2026 · 7 min read
Tags: transformer, attention mechanism, multi-head attention, self-attention, deep learning

Attention is the core mechanism that makes transformers work. This guide explains attention from first principles, building up to modern variants used in production LLMs.

The Attention Mechanism

Intuition

Attention answers: "Which parts of the input should I focus on to process each position?"

Input: "The cat sat on the mat"

When processing "sat":
- High attention to "cat" (subject)
- Medium attention to "mat" (object)
- Low attention to "the" (less relevant)

Scaled Dot-Product Attention

The fundamental attention computation:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: Queries [batch, seq_len, d_k]
    K: Keys    [batch, seq_len, d_k]
    V: Values  [batch, seq_len, d_v]
    mask: optional mask broadcastable to [seq_len, seq_len]; 0 = blocked position
    """
    d_k = Q.shape[-1]

    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch, seq, seq]

    # Step 2: Scale by sqrt(d_k) for stable gradients
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply mask (for causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 4: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)  # [batch, seq, seq]

    # Step 5: Apply attention to values
    output = torch.matmul(attention_weights, V)  # [batch, seq, d_v]

    return output, attention_weights
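
To sanity-check the shapes, here is a minimal usage sketch (the batch size and dimensions are arbitrary example values):

# Hypothetical shapes, chosen only for illustration
batch, seq_len, d_k, d_v = 2, 6, 64, 64
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)
print(output.shape)         # torch.Size([2, 6, 64])
print(weights.shape)        # torch.Size([2, 6, 6])
print(weights.sum(dim=-1))  # each row sums to 1 after softmax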

Why Scale by √d_k?

Without scaling, dot products grow with dimension, pushing softmax into saturated regions:

# Without scaling (d_k=512)
dot_product = sum(q_i * k_i for i in range(512))
# If q, k ~ N(0,1), dot_product ~ N(0, 512)
# Large values → softmax approaches one-hot → vanishing gradients

# With scaling
scaled = dot_product / sqrt(512)  # ~ N(0, 1)
# Softmax stays in "useful" range
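
A quick empirical check of this argument, as a standalone sketch (the sample counts are arbitrary):

import math
import torch

d_k = 512
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

dots = (q * k).sum(dim=-1)
print(dots.std())                     # ≈ sqrt(512) ≈ 22.6
print((dots / math.sqrt(d_k)).std())  # ≈ 1.0

# Unscaled scores saturate the softmax; scaled scores keep it spread out
scores = torch.randn(8) * 22.6
print(torch.softmax(scores, dim=-1).max())                   # typically very close to 1
print(torch.softmax(scores / math.sqrt(d_k), dim=-1).max())  # noticeably flatter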

Query, Key, Value

Creating Q, K, V

import torch.nn as nn

class AttentionProjections(nn.Module):
    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_v)

    def forward(self, x):
        Q = self.W_q(x)  # What am I looking for?
        K = self.W_k(x)  # What do I contain?
        V = self.W_v(x)  # What information do I provide?
        return Q, K, V
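
Putting the pieces together, a minimal single-head pipeline might look like the following sketch (shapes are illustrative; it reuses the scaled_dot_product_attention function defined above):

d_model, d_k, d_v = 512, 64, 64
proj = AttentionProjections(d_model, d_k, d_v)

x = torch.randn(2, 10, d_model)  # [batch, seq, d_model]
Q, K, V = proj(x)                # each [2, 10, 64]
output, weights = scaled_dot_product_attention(Q, K, V)
print(output.shape)              # torch.Size([2, 10, 64])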

Intuitive Understanding

Component | Analogy  | Purpose
Query     | Question | "What information do I need?"
Key       | Index    | "What information do I have?"
Value     | Content  | "Here's my actual information"

Multi-Head Attention

Why Multiple Heads?

Different heads can attend to different aspects:

  • Head 1: Syntactic relationships
  • Head 2: Semantic similarity
  • Head 3: Positional patterns

Implementation

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape

        # Project and reshape to [batch, heads, seq, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(attention, V)

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(out)
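
A quick usage sketch (dimensions chosen to match the visualization below; it assumes the imports from the earlier snippets):

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(4, 128, 512)  # [batch, seq, d_model]

out = mha(x)                  # self-attention over x
print(out.shape)              # torch.Size([4, 128, 512])

# With a causal mask (see the causal attention section below)
causal = torch.tril(torch.ones(128, 128))
out = mha(x, mask=causal)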

Visualization

Input: [batch, seq, d_model=512]
            ↓
Split into 8 heads: [batch, 8, seq, 64]
            ↓
Each head computes attention independently
            ↓
Concatenate: [batch, seq, 512]
            ↓
Output projection: [batch, seq, 512]

Self-Attention vs Cross-Attention

Self-Attention

Q, K, V all come from the same sequence:

# Self-attention: sequence attends to itself
Q = K = V = encoder_output
output = attention(Q, K, V)

Used in: Encoder layers, decoder self-attention

Cross-Attention

Q from one sequence, K, V from another:

# Cross-attention: decoder attends to encoder
Q = decoder_state
K = V = encoder_output
output = attention(Q, K, V)

Used in: Decoder attending to encoder in seq2seq models
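
As a concrete sketch, cross-attention differs from the multi-head module above only in where K and V are projected from. The class below is hypothetical, following the same conventions and imports as the earlier code:

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, context):
        # x: decoder states [batch, tgt_len, d_model]
        # context: encoder output [batch, src_len, d_model]
        b, t, _ = x.shape
        s = context.shape[1]

        # Queries from the decoder, keys/values from the encoder
        Q = self.W_q(x).view(b, t, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(context).view(b, s, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(context).view(b, s, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = F.softmax(scores, dim=-1)  # [batch, heads, tgt_len, src_len]
        out = torch.matmul(weights, V)

        out = out.transpose(1, 2).contiguous().view(b, t, -1)
        return self.W_o(out)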

Causal (Masked) Attention

For autoregressive models, prevent attending to future tokens:

def causal_mask(seq_len):
    """Create lower triangular mask."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# Effect:
# Position 0 can attend to: [0]
# Position 1 can attend to: [0, 1]
# Position 2 can attend to: [0, 1, 2]
# ...
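
Plugging this mask into the scaled_dot_product_attention function from earlier makes the effect visible (tiny sizes chosen for readability):

Q = K = V = torch.randn(1, 4, 8)
mask = causal_mask(4)

_, weights = scaled_dot_product_attention(Q, K, V, mask=mask)
print(weights[0].round(decimals=2))
# Upper-triangular entries are ~0: row 0 is [1.00, 0.00, 0.00, 0.00],
# so each position only attends to itself and earlier positions.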

Modern Attention Variants

Multi-Query Attention (MQA)

Single K, V shared across all heads:

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)  # Full size
        self.W_k = nn.Linear(d_model, self.d_k)  # Single head!
        self.W_v = nn.Linear(d_model, self.d_k)  # Single head!
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x)  # [batch, seq, d_model]
        K = self.W_k(x)  # [batch, seq, d_k]
        V = self.W_v(x)  # [batch, seq, d_k]

        # Split queries into heads; K and V stay single-headed
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.unsqueeze(1)  # [batch, 1, seq, d_k]
        V = V.unsqueeze(1)  # [batch, 1, seq, d_k]

        # The single K, V broadcast across all query heads during attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = F.softmax(scores, dim=-1)
        out = torch.matmul(attention, V)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(out)

Benefit: Much smaller KV cache (1/num_heads)
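
Rough back-of-the-envelope arithmetic for that saving (the model dimensions below are hypothetical, chosen only to make the numbers concrete):

# Per-token KV cache = 2 (K and V) × layers × kv_heads × d_k × bytes per value
layers, num_heads, d_k = 32, 32, 128
bytes_fp16 = 2
seq_len = 8192

mha_cache = 2 * layers * num_heads * d_k * bytes_fp16 * seq_len
mqa_cache = 2 * layers * 1 * d_k * bytes_fp16 * seq_len

print(f"MHA: {mha_cache / 1e9:.1f} GB, MQA: {mqa_cache / 1e9:.2f} GB")
# MHA: 4.3 GB, MQA: 0.13 GB (a num_heads-fold reduction)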

Grouped-Query Attention (GQA)

Groups of heads share K, V:

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.heads_per_group = num_heads // num_kv_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)
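
The __init__ above only shows the projections. A forward pass could look roughly like the following sketch (not a reference implementation), continuing the class above: each K/V head is repeated so that a whole group of query heads shares it.

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)

        # Expand K, V so each group of query heads sees its shared KV head
        K = K.repeat_interleave(self.heads_per_group, dim=1)  # [batch, num_heads, seq, d_k]
        V = V.repeat_interleave(self.heads_per_group, dim=1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        out = torch.matmul(F.softmax(scores, dim=-1), V)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(out)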

Used in: LLaMA 2 70B (8 KV heads for 64 Q heads)

Sliding Window Attention

Limit attention to local context (Mistral):

def sliding_window_mask(seq_len, window_size):
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size)
        mask[i, start:i+1] = 1
    return mask

Benefit: O(n × window) instead of O(n²)
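
For a small example, the banded structure is easy to see:

print(sliding_window_mask(5, window_size=2))
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [0., 1., 1., 1., 0.],
#         [0., 0., 1., 1., 1.]])
# Each query attends to itself and at most window_size previous positions.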

Attention Complexity

Variant         | Time         | Memory   | KV Cache
Standard MHA    | O(n²d)       | O(n²)    | O(n × h × d)
Flash Attention | O(n²d)       | O(n)     | O(n × h × d)
MQA             | O(n²d)       | O(n²)    | O(n × d)
GQA             | O(n²d)       | O(n²)    | O(n × g × d)
Sliding Window  | O(n × w × d) | O(n × w) | O(w × h × d)

(n = sequence length, d = head dimension, h = number of heads, g = number of KV groups, w = window size)

PyTorch Native Implementation

# PyTorch 2.0+ scaled_dot_product_attention
import torch.nn.functional as F

output = F.scaled_dot_product_attention(
    query, key, value,
    dropout_p=0.1,
    is_causal=True,  # Efficient built-in causal masking
)
# Note: attn_mask and is_causal are mutually exclusive; pass attn_mask=...
# (with is_causal=False) when you need a custom mask.

# Automatically dispatches to Flash Attention when available
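
As a sanity check, the built-in version should match the hand-written function from the start of this guide (with dropout disabled), up to numerical tolerance. A minimal comparison, assuming both earlier snippets have been run:

q = k = v = torch.randn(2, 8, 16, 64)  # [batch, heads, seq, d_k]

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

mask = torch.tril(torch.ones(16, 16))
ours, _ = scaled_dot_product_attention(q, k, v, mask=mask)

print(torch.allclose(ref, ours, atol=1e-5))  # True (up to numerical tolerance)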

References

  1. Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS 2017

  2. Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150

  3. Ainslie, J., et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." arXiv:2305.13245

  4. (2025). "Cost-Optimal Grouped-Query Attention for Long-Context Modeling." EMNLP 2025
