
KV Cache Explained: How Transformers Accelerate Autoregressive Generation

Deep dive into the KV cache mechanism in transformers. Learn how it works, memory requirements, optimization techniques like MQA/GQA, and paged attention implementations.

Flash Attention Team · January 8, 2026 · 9 min read

Tags: KV cache, transformer inference, autoregressive generation, MQA, GQA, paged attention

The KV cache is a fundamental optimization that makes large language model inference practical. Without it, generating each token would require recomputing attention over the entire sequence. This guide explains exactly how it works and how to optimize it.

The Problem: Redundant Computation

In autoregressive generation, we predict one token at a time:

# Naive implementation (extremely slow)
def generate_naive(model, prompt_tokens, max_new_tokens):
    tokens = list(prompt_tokens)  # copy so we don't mutate the caller's prompt

    for _ in range(max_new_tokens):
        # Full forward pass over ALL tokens for EACH new token
        logits = model(tokens)  # Cost: O(n²) per token
        next_token = sample(logits[-1])
        tokens.append(next_token)

    return tokens

For sequence length n, this costs O(n²) per token, O(n³) total.

How KV Cache Works

Attention Computation

Standard attention computes:

# Project every position of the input X into queries, keys, and values
Q = X @ W_q  # Query
K = X @ W_k  # Key
V = X @ W_v  # Value

# Attention scores
scores = Q @ K.T / sqrt(d_k)
attention_weights = softmax(scores)
output = attention_weights @ V

The key insight: under causal attention, earlier positions never attend to later ones, so the K and V vectors for previous positions don't change when we add a new token.

Caching Keys and Values

import math
import torch

class CachedAttention:
    def __init__(self, d_model, d_k):
        self.d_k = d_k
        # Projection weights (these are nn.Linear layers in a real model)
        self.W_q = torch.randn(d_model, d_k)
        self.W_k = torch.randn(d_model, d_k)
        self.W_v = torch.randn(d_model, d_k)
        self.k_cache = []  # List of past keys
        self.v_cache = []  # List of past values

    def forward(self, x, use_cache=True):
        # Compute Q, K, V for the current position only
        q = x @ self.W_q  # Shape: [1, d_k] (just the new token)
        k = x @ self.W_k
        v = x @ self.W_v

        if use_cache:
            # Append to cache
            self.k_cache.append(k)
            self.v_cache.append(v)

            # Attend to all cached keys/values
            K = torch.cat(self.k_cache, dim=0)  # [seq_len, d_k]
            V = torch.cat(self.v_cache, dim=0)  # [seq_len, d_v]
        else:
            K, V = k, v

        # Compute attention (new query against all cached keys)
        scores = q @ K.T / math.sqrt(self.d_k)  # [1, seq_len]
        weights = torch.softmax(scores, dim=-1)
        output = weights @ V  # [1, d_v]

        return output
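
To make the mechanics concrete, here is a toy decode loop driving the class above (a hypothetical sketch with illustrative dimensions): each call passes in only the newest position, and the cache supplies every earlier key and value.

import torch

d_model, d_k = 512, 64
attn = CachedAttention(d_model, d_k)

# "Prefill" the prompt one position at a time (real engines batch this step)
prompt = torch.randn(5, d_model)          # 5 prompt tokens
for t in range(prompt.shape[0]):
    out = attn.forward(prompt[t:t + 1])   # [1, d_model] in, [1, d_k] out

# Decode: a new token only needs its own projections plus the cache
new_token = torch.randn(1, d_model)
out = attn.forward(new_token)             # attends over all 6 cached positions
print(len(attn.k_cache))                  # 6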

Complexity Improvement

Phase            | Without Cache   | With Cache
Per token        | O(n²) attention | O(n) attention
Total generation | O(n³)           | O(n²)
Memory           | O(1)            | O(n) for cache

For a 1,000-token generation, that works out to roughly a 1000x reduction in attention compute.
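
You can observe the gap directly with Hugging Face transformers by toggling use_cache in generate (a rough sketch; gpt2 is just a small model chosen for a quick local test, and absolute timings depend on your hardware):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
inputs = tokenizer("The KV cache", return_tensors="pt")

def timed_generate(use_cache):
    start = time.perf_counter()
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=200,
                       do_sample=False, use_cache=use_cache)
    return time.perf_counter() - start

print(f"with cache:    {timed_generate(True):.2f}s")
print(f"without cache: {timed_generate(False):.2f}s")  # noticeably slower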

Memory Requirements

KV Cache Size Formula

KV Cache Memory = 2 × L × H × D × S × B × bytes_per_value

Where:
- 2: Keys and Values
- L: Number of layers
- H: Number of KV heads (equal to the number of attention heads for standard MHA; fewer for GQA/MQA)
- D: Head dimension (typically hidden_dim / num_heads)
- S: Sequence length
- B: Batch size
- bytes_per_value: 2 for FP16, 1 for INT8
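
To sanity-check the formula, here is a small helper applied to the LLaMA-7B row of the table below (a sketch assuming FP16 and batch size 1; the helper name is ours, not a library function):

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    # The leading 2 accounts for keys and values
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_value

# LLaMA-7B at a 2K context in FP16
size = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=2048)
print(f"{size / 1e9:.2f} GB")  # ~1.07 GB, matching the table below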

Real-World Examples

Model      | Layers | Heads | Head Dim | 2K Context | 8K Context
LLaMA-7B   | 32     | 32    | 128      | 1 GB       | 4 GB
LLaMA-13B  | 40     | 40    | 128      | 1.6 GB     | 6.4 GB
LLaMA-70B  | 80     | 64    | 128      | 5.2 GB     | 20.8 GB

With long sequences and large batch sizes, the KV cache can rival or exceed the memory used by the model weights. (The figures above assume standard multi-head attention; GQA models such as LLaMA-2-70B cache only their KV heads, which shrinks these numbers considerably.)

Attention Variants for Smaller Cache

Multi-Query Attention (MQA)

Share K and V across all heads, keep separate Q:

import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Separate Q projection per head
        self.W_q = nn.Linear(d_model, d_model)

        # Single K, V projection shared across all heads (much smaller!)
        self.W_k = nn.Linear(d_model, self.head_dim)
        self.W_v = nn.Linear(d_model, self.head_dim)

    def forward(self, x):
        B, S, D = x.shape

        # Q: [B, S, num_heads, head_dim]
        Q = self.W_q(x).view(B, S, self.num_heads, self.head_dim)

        # K, V: [B, S, 1, head_dim] - shared across heads, so only one
        # head's worth of K/V ever enters the KV cache
        K = self.W_k(x).unsqueeze(2)
        V = self.W_v(x).unsqueeze(2)

        # Broadcast the shared K, V to every query head
        K = K.expand(B, S, self.num_heads, self.head_dim)
        V = V.expand(B, S, self.num_heads, self.head_dim)

        # Attention scores: [B, num_heads, S, S] (causal mask omitted for brevity)
        scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) / self.head_dim ** 0.5
        weights = torch.softmax(scores, dim=-1)
        out = torch.einsum("bhqk,bkhd->bqhd", weights, V)
        return out.reshape(B, S, D)

KV cache memory reduction: num_heads× (32x for LLaMA-7B, which has 32 heads).

Grouped-Query Attention (GQA)

Middle ground—groups of heads share K, V:

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.heads_per_group = num_heads // num_kv_heads
        self.head_dim = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        # Only num_kv_heads sets of K, V are computed (and cached)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.head_dim)

Model        | Q Heads | KV Heads | Reduction
LLaMA-2 7B   | 32      | 32       | 1x (MHA)
LLaMA-2 70B  | 64      | 8        | 8x (GQA)
Mistral 7B   | 32      | 8        | 4x (GQA)
Falcon 40B   | 64      | 1        | 64x (MQA)

Quality Comparison

Attention Type    | Memory | Quality vs MHA
MHA               | 100%   | Baseline
GQA (8 KV heads)  | 12.5%  | -0.5%
MQA (1 KV head)   | 3%     | -2%

GQA provides the best quality-memory trade-off for large models.

Paged Attention (vLLM)

The Memory Fragmentation Problem

Traditional KV cache pre-allocates contiguous memory:

# Traditional approach
def allocate_cache(max_seq_len, hidden_dim):
    # Pre-allocate for maximum possible sequence
    return torch.zeros(max_seq_len, hidden_dim)

# Problem: Most sequences are shorter
# Wastes memory: allocated - actual_used

With variable-length requests:

  • Request 1: Needs 100 tokens, allocated 2048 → ~95% wasted
  • Request 2: Needs 500 tokens, allocated 2048 → ~76% wasted

Paged Attention Solution

Manage KV cache like operating system virtual memory:

import torch

class PagedKVCache:
    def __init__(self, hidden_dim, block_size=16, num_blocks=1024):
        self.hidden_dim = hidden_dim
        self.block_size = block_size
        # Pre-allocate a fixed pool of physical blocks
        self.physical_blocks = [
            torch.zeros(block_size, hidden_dim)
            for _ in range(num_blocks)
        ]
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> [physical block ids]

    def allocate(self, seq_id):
        """Allocate a new physical block for a sequence."""
        if not self.free_blocks:
            raise MemoryError("No free blocks")
        block_id = self.free_blocks.pop()
        if seq_id not in self.block_tables:
            self.block_tables[seq_id] = []
        self.block_tables[seq_id].append(block_id)
        return block_id

    def append_token(self, seq_id, k, v, position):
        """Add K, V for a new token at the given position."""
        if seq_id not in self.block_tables:
            self.block_tables[seq_id] = []
        block_idx = position // self.block_size
        offset = position % self.block_size

        # Allocate new blocks if the sequence has outgrown its last block
        while len(self.block_tables[seq_id]) <= block_idx:
            self.allocate(seq_id)

        physical_block = self.block_tables[seq_id][block_idx]
        # Store the concatenated [k, v] in one row (a simplification)
        self.physical_blocks[physical_block][offset] = torch.cat([k, v])

    def free(self, seq_id):
        """Release all blocks for a completed sequence back to the pool."""
        for block_id in self.block_tables[seq_id]:
            self.free_blocks.append(block_id)
        del self.block_tables[seq_id]
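
A toy walkthrough of the class above (hidden_dim and tensor sizes are illustrative): two sequences of different lengths draw blocks from the same pool, and a finished sequence returns its blocks immediately.

import torch

cache = PagedKVCache(hidden_dim=256, block_size=16, num_blocks=64)

# Sequence "a" needs 40 positions -> 3 blocks; sequence "b" needs 10 -> 1 block
for pos in range(40):
    cache.append_token("a", k=torch.randn(128), v=torch.randn(128), position=pos)
for pos in range(10):
    cache.append_token("b", k=torch.randn(128), v=torch.randn(128), position=pos)

print(len(cache.block_tables["a"]), len(cache.block_tables["b"]))  # 3 1
print(len(cache.free_blocks))  # 60 blocks still free

cache.free("a")                # blocks go straight back to the pool
print(len(cache.free_blocks))  # 63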

Benefits of Paged Attention

  1. Near-zero waste: Only allocate blocks as needed
  2. Better batching: Fit more sequences in memory
  3. Copy-on-write: Share prefix blocks across sequences

# Copy-on-write for shared prefixes
# System prompt: "You are a helpful assistant..."
# All requests share the same physical prefix blocks

def share_prefix(source_seq, new_seq, prefix_len):
    """Point the new sequence at the source's prefix blocks without copying."""
    prefix_blocks = prefix_len // block_size
    new_seq.block_table[:prefix_blocks] = source_seq.block_table[:prefix_blocks]
    for block_id in new_seq.block_table[:prefix_blocks]:
        ref_counts[block_id] += 1  # one global ref count per physical block

# Only copy a block when a sequence writes to a shared one
def copy_on_write(seq, logical_idx):
    physical_id = seq.block_table[logical_idx]
    if ref_counts[physical_id] > 1:
        new_block = allocate_block()
        copy(physical_blocks[new_block], physical_blocks[physical_id])
        seq.block_table[logical_idx] = new_block
        ref_counts[physical_id] -= 1

Memory Efficiency Comparison

Approach       | Utilization | Max Batch (A100 80GB)
Pre-allocated  | ~20%        | 8 sequences
Paged (vLLM)   | ~90%        | 40+ sequences

Implementation in PyTorch

Simple KV Cache

import torch
import torch.nn as nn

class KVCache:
    def __init__(self, max_batch_size, max_seq_len, num_layers, num_heads, head_dim):
        self.max_seq_len = max_seq_len
        # Pre-allocate the full cache: [layers, batch, seq, heads, head_dim]
        self.k_cache = torch.zeros(
            num_layers, max_batch_size, max_seq_len, num_heads, head_dim
        )
        self.v_cache = torch.zeros(
            num_layers, max_batch_size, max_seq_len, num_heads, head_dim
        )
        # Track cached positions per layer so each layer can update independently
        self.seq_lens = [0] * num_layers

    def update(self, layer_idx, k, v):
        """Append new K, V ([batch, new_seq_len, heads, head_dim]) for one layer."""
        batch_size, new_seq_len = k.shape[0], k.shape[1]
        start = self.seq_lens[layer_idx]

        self.k_cache[layer_idx, :batch_size, start:start + new_seq_len] = k
        self.v_cache[layer_idx, :batch_size, start:start + new_seq_len] = v

        self.seq_lens[layer_idx] += new_seq_len

    def get(self, layer_idx, batch_size):
        """Retrieve all cached K, V for one layer."""
        end = self.seq_lens[layer_idx]
        return (
            self.k_cache[layer_idx, :batch_size, :end],
            self.v_cache[layer_idx, :batch_size, :end],
        )

    def reset(self):
        self.seq_lens = [0] * len(self.seq_lens)
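
A quick usage sketch of this class with made-up shapes ([batch, seq, heads, head_dim]), showing a prefill write followed by a single-token decode step for one layer:

import torch

cache = KVCache(max_batch_size=1, max_seq_len=2048,
                num_layers=2, num_heads=32, head_dim=128)

# Prefill: write K, V for a 10-token prompt in one shot (layer 0 shown)
cache.update(layer_idx=0, k=torch.randn(1, 10, 32, 128),
             v=torch.randn(1, 10, 32, 128))

# Decode: append K, V for one new token
cache.update(layer_idx=0, k=torch.randn(1, 1, 32, 128),
             v=torch.randn(1, 1, 32, 128))

K, V = cache.get(layer_idx=0, batch_size=1)
print(K.shape)  # torch.Size([1, 11, 32, 128])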

Using with Hugging Face

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Generate with KV cache (enabled by default)
inputs = tokenizer("Hello, world!", return_tensors="pt")

# generate() runs a prefill pass over the prompt, then cached decoding for each new token
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    use_cache=True,  # Enable KV caching
    return_dict_in_generate=True,
)

# Access past key values
past_key_values = outputs.past_key_values
# Legacy format: a tuple of num_layers (key, value) pairs, each of shape
# [batch, num_heads, seq_len, head_dim]; newer transformers versions return a Cache object

Static KV Cache (PyTorch 2.0+)

# Static cache for torch.compile compatibility
from transformers import StaticCache

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
)

# Create static cache
cache = StaticCache(
    config=model.config,
    max_batch_size=4,
    max_cache_len=2048,
    dtype=torch.float16,
)

# Use with generate (reusing the tokenized inputs from above)
outputs = model.generate(
    **inputs,
    past_key_values=cache,
    max_new_tokens=100,
)

Optimization Tips

1. Right-size Your Cache

# Don't over-allocate: size the cache to what you actually need
prompt_lens = [len(tokenizer.encode(p)) for p in prompts]  # token counts, not characters
actual_max_len = max(prompt_lens) + max_new_tokens
cache = create_cache(max_seq_len=actual_max_len)  # not 32K if you only need 2K

2. Use Appropriate Precision

# FP16 halves KV cache memory
cache = cache.half()

# INT8 KV cache (experimental)
# Some frameworks support quantized KV cache
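
There is no single standard API for INT8 KV caches yet, so as an illustration only, here is one way the idea can work: store K (and V) as int8 with one scale per head, and dequantize right before the attention matmul. The helper names below are ours:

import torch

def quantize_kv(x):
    """Symmetric int8 quantization; x is [seq_len, num_heads, head_dim]."""
    scale = x.abs().amax(dim=(0, 2), keepdim=True) / 127.0  # one scale per head
    q = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
    return q, scale

def dequantize_kv(q, scale):
    # Convert back to FP16 just before computing attention scores
    return q.to(torch.float16) * scale.to(torch.float16)

k = torch.randn(1024, 32, 128, dtype=torch.float16)
k_q, k_scale = quantize_kv(k)
print(k_q.element_size())                              # 1 byte per value instead of 2
print((dequantize_kv(k_q, k_scale) - k).abs().max())   # small quantization error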

3. Sliding Window Attention

For very long sequences, use sliding window (Mistral-style):

# Only cache last window_size tokens
window_size = 4096
if seq_len > window_size:
    k_cache = k_cache[:, -window_size:]
    v_cache = v_cache[:, -window_size:]

4. Speculative Decoding Cache Management

# When using speculative decoding, manage caches for both the draft and verify passes
def speculative_step(draft_model, target_model, cache, k=4):
    # Draft k candidate tokens with the cheap draft model (it keeps its own cache)
    draft_tokens, draft_cache = draft_model.generate(num_tokens=k)

    # Verify all candidates with the target model in a single forward pass
    accepted, target_cache = target_model.verify(draft_tokens, draft_cache)

    # Roll the target cache back to the last accepted position
    cache.truncate(len(accepted))

References

  1. Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS 2017

  2. Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150

  3. Ainslie, J., et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." arXiv:2305.13245

  4. Kwon, W., et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023

  5. PagedEviction. (2025). "PagedEviction: Structured Block-wise KV Cache Pruning." arXiv:2509.04377
