The KV cache is a fundamental optimization that makes large language model inference practical. Without it, generating each token would require recomputing attention over the entire sequence. This guide explains exactly how it works and how to optimize it.
The Problem: Redundant Computation
In autoregressive generation, we predict one token at a time:
```python
# Naive implementation (extremely slow)
def generate_naive(model, prompt_tokens, max_new_tokens):
    tokens = prompt_tokens
    for _ in range(max_new_tokens):
        # Full forward pass over ALL tokens for EACH new token
        logits = model(tokens)           # attention cost: O(n²) per step
        next_token = sample(logits[-1])  # sample from the last position's logits
        tokens.append(next_token)
    return tokens
```
For sequence length n, this costs O(n²) per token, O(n³) total.
How KV Cache Works
Attention Computation
Standard attention computes:
```python
# Project the full sequence X into queries, keys, and values
Q = X @ W_q  # Queries
K = X @ W_k  # Keys
V = X @ W_v  # Values

# Attention scores
scores = Q @ K.T / sqrt(d_k)
attention_weights = softmax(scores)
output = attention_weights @ V
```
The key insight: K and V for previous positions don't change when we add a new token.
Caching Keys and Values
```python
import math
import torch

class CachedAttention:
    def __init__(self, d_model, d_k, d_v):
        self.d_k = d_k
        # Projection matrices (randomly initialized here for illustration)
        self.W_q = torch.randn(d_model, d_k)
        self.W_k = torch.randn(d_model, d_k)
        self.W_v = torch.randn(d_model, d_v)
        self.k_cache = []  # List of past keys
        self.v_cache = []  # List of past values

    def forward(self, x, use_cache=True):
        # Compute Q, K, V for the current position only
        q = x @ self.W_q  # Shape: [1, d_k] (just the new token)
        k = x @ self.W_k
        v = x @ self.W_v
        if use_cache:
            # Append to cache
            self.k_cache.append(k)
            self.v_cache.append(v)
            # Attend to all cached keys/values
            K = torch.cat(self.k_cache, dim=0)  # [seq_len, d_k]
            V = torch.cat(self.v_cache, dim=0)  # [seq_len, d_v]
        else:
            K, V = k, v
        # Compute attention (new query against all cached keys)
        scores = q @ K.T / math.sqrt(self.d_k)  # [1, seq_len]
        weights = torch.softmax(scores, dim=-1)
        return weights @ V  # [1, d_v]
```
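A quick usage sketch of the class above, feeding one token embedding at a time (the dimensions are arbitrary, chosen only for illustration):

```python
attn = CachedAttention(d_model=64, d_k=64, d_v=64)

for step in range(5):
    x_t = torch.randn(1, 64)  # embedding of the newly generated token
    out = attn.forward(x_t)   # attends over all step + 1 cached positions

print(len(attn.k_cache))  # 5 cached key vectors, one per generated token
```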
Complexity Improvement
| Metric | Without Cache | With Cache |
|---|---|---|
| Attention per new token | O(n²) | O(n) |
| Total generation | O(n³) | O(n²) |
| Cache memory | None (everything recomputed) | O(n) |

For 1000 generated tokens, this is roughly three orders of magnitude less redundant attention computation.
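As a rough sanity check, the snippet below counts attention FLOPs for both strategies under a simplified cost model (single head, scores plus weighted sum only, ignoring projections); the constants are assumptions, but the gap of several hundred times at n = 1000 is the point.

```python
def attention_flops(seq_len, d_k, cached):
    # Score matrix + weighted sum; with a cache only one query row is computed
    query_rows = 1 if cached else seq_len
    return 2 * (2 * query_rows * seq_len * d_k)  # two matmuls, 2 FLOPs per MAC

def total_flops(n_tokens, d_k, cached):
    return sum(attention_flops(t, d_k, cached) for t in range(1, n_tokens + 1))

n, d_k = 1000, 128
print(total_flops(n, d_k, cached=False) / total_flops(n, d_k, cached=True))
# ~667x fewer attention FLOPs; the ratio grows roughly as 2n/3
```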
Memory Requirements
KV Cache Size Formula
KV Cache Memory = 2 × L × H × D × S × B × bytes_per_value
Where:
- 2: Keys and Values
- L: Number of layers
- H: Number of KV heads (equal to the number of attention heads for standard multi-head attention)
- D: Head dimension (typically hidden_dim / num_heads)
- S: Sequence length
- B: Batch size
- bytes_per_value: 2 for FP16, 1 for INT8
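A small helper (the function name is ours) that evaluates this formula makes the table below easy to reproduce:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    # Factor of 2 accounts for storing both K and V
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_value

# LLaMA-7B (32 layers, 32 heads, head_dim 128), 2K context, batch 1, FP16
print(kv_cache_bytes(32, 32, 128, 2048) / 1024**3)  # ~1.0 GiB, matching the table
```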
Real-World Examples
| Model | Layers | Heads | Head Dim | 2K Context | 8K Context |
|---|---|---|---|---|---|
| LLaMA-7B | 32 | 32 | 128 | 1 GB | 4 GB |
| LLaMA-13B | 40 | 40 | 128 | 1.6 GB | 6.4 GB |
| LLaMA-70B | 80 | 64 | 128 | 5.2 GB | 20.8 GB |

Sizes assume FP16 and batch size 1. The LLaMA-70B row also assumes full multi-head attention with 64 KV heads; the released LLaMA-2 70B uses GQA (covered below), which cuts its cache by 8x. Even so, at long contexts and realistic batch sizes the KV cache can rival or exceed the memory used by the model weights.
Attention Variants for Smaller Cache
Multi-Query Attention (MQA)
Share K and V across all heads, keep separate Q:
```python
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Separate Q projection per head
        self.W_q = nn.Linear(d_model, d_model)
        # Single K, V projection shared across all heads -- much smaller!
        self.W_k = nn.Linear(d_model, self.head_dim)
        self.W_v = nn.Linear(d_model, self.head_dim)

    def forward(self, x):
        B, S, D = x.shape
        # Q: [B, num_heads, S, head_dim]
        Q = self.W_q(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        # K, V: [B, 1, S, head_dim] -- broadcast across the head dimension
        K = self.W_k(x).unsqueeze(1)
        V = self.W_v(x).unsqueeze(1)
        # Attention: [B, num_heads, S, S]
        scores = Q @ K.transpose(-2, -1) / self.head_dim ** 0.5
        weights = torch.softmax(scores, dim=-1)
        # Weighted values, reshaped back to [B, S, d_model]
        return (weights @ V).transpose(1, 2).reshape(B, S, D)
```
KV cache memory reduction: a factor of num_heads (32x for LLaMA-7B).
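Plugging into the cache-size formula above: LLaMA-7B's roughly 1 GB cache at a 2K context (32 KV heads) would shrink to about 1 GB / 32 ≈ 32 MB with a single shared KV head.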
Grouped-Query Attention (GQA)
A middle ground: groups of query heads share one set of K, V heads:

```python
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.heads_per_group = num_heads // num_kv_heads
        self.head_dim = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.head_dim)
```
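A sketch of the matching forward pass, using the common approach of repeating each KV head across its query group (this continues the class above and assumes `torch` is imported):

```python
    def forward(self, x):
        B, S, D = x.shape
        Q = self.W_q(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        # K, V use only num_kv_heads heads -- this smaller projection is what gets cached
        K = self.W_k(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head for the query heads in its group:
        # [B, num_kv_heads, S, d] -> [B, num_heads, S, d]
        K = K.repeat_interleave(self.heads_per_group, dim=1)
        V = V.repeat_interleave(self.heads_per_group, dim=1)
        scores = Q @ K.transpose(-2, -1) / self.head_dim ** 0.5
        weights = torch.softmax(scores, dim=-1)
        return (weights @ V).transpose(1, 2).reshape(B, S, D)
```

The detail that matters for the KV cache is that only the num_kv_heads projections of K and V are stored; the repeat happens at attention time.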
| Model | Q Heads | KV Heads | Reduction |
|---|---|---|---|
| LLaMA-2 7B | 32 | 32 | 1x (MHA) |
| LLaMA-2 70B | 64 | 8 | 8x (GQA) |
| Mistral 7B | 32 | 8 | 4x (GQA) |
| Falcon 7B | 71 | 1 | 71x (MQA) |
Quality Comparison
| Attention Type | Memory | Quality vs MHA |
|---|---|---|
| MHA | 100% | Baseline |
| GQA (8 KV heads) | 12.5% | -0.5% |
| MQA (1 KV head) | 3% | -2% |
GQA provides the best quality-memory trade-off for large models.
Paged Attention (vLLM)
The Memory Fragmentation Problem
Traditional KV cache pre-allocates contiguous memory:
```python
# Traditional approach
def allocate_cache(max_seq_len, hidden_dim):
    # Pre-allocate for the maximum possible sequence length
    return torch.zeros(max_seq_len, hidden_dim)

# Problem: most sequences are much shorter,
# so (allocated - actually used) memory is wasted
```
With variable-length requests:
- Request 1: Needs 100 tokens, allocated 2048 → ~95% waste
- Request 2: Needs 500 tokens, allocated 2048 → 76% waste
Paged Attention Solution
Manage KV cache like operating system virtual memory:
```python
class PagedKVCache:
    def __init__(self, block_size=16, num_blocks=1024, hidden_dim=128):
        self.block_size = block_size
        # Pre-allocate fixed-size physical blocks; each slot holds one token's K and V
        self.physical_blocks = [
            torch.zeros(block_size, 2, hidden_dim)
            for _ in range(num_blocks)
        ]
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> [block_ids]

    def allocate(self, seq_id):
        """Allocate a new block for a sequence."""
        if not self.free_blocks:
            raise MemoryError("No free blocks")
        block_id = self.free_blocks.pop()
        self.block_tables.setdefault(seq_id, []).append(block_id)
        return block_id

    def append_token(self, seq_id, k, v, position):
        """Add K, V for a new token."""
        block_idx = position // self.block_size
        offset = position % self.block_size
        # Allocate new blocks lazily, only when the sequence grows into them
        while len(self.block_tables.setdefault(seq_id, [])) <= block_idx:
            self.allocate(seq_id)
        physical_block = self.block_tables[seq_id][block_idx]
        self.physical_blocks[physical_block][offset] = torch.stack((k, v))

    def free(self, seq_id):
        """Release all blocks for a completed sequence."""
        for block_id in self.block_tables[seq_id]:
            self.free_blocks.append(block_id)
        del self.block_tables[seq_id]
```
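A minimal usage sketch of the class above (the dimensions and token values are made up for illustration):

```python
cache = PagedKVCache(block_size=16, num_blocks=1024, hidden_dim=128)

# Sequence 7 generates 40 tokens; blocks are allocated lazily as it grows
for position in range(40):
    k = torch.randn(128)
    v = torch.randn(128)
    cache.append_token(seq_id=7, k=k, v=v, position=position)

print(len(cache.block_tables[7]))  # 3 blocks (48 slots) instead of a full pre-allocation
cache.free(7)                      # the 3 blocks return to the free pool
```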
Benefits of Paged Attention
- Near-zero waste: Only allocate blocks as needed
- Better batching: Fit more sequences in memory
- Copy-on-write: Share prefix blocks across sequences
```python
# Copy-on-write for shared prefixes
# System prompt: "You are a helpful assistant..."
# All requests share the same prefix blocks
# (schematic pseudocode: block_size, ref_counts, allocate_block, copy, and
#  physical_blocks live in the cache manager)

def share_prefix(source_seq, new_seq, prefix_len):
    """Share prefix blocks without copying."""
    prefix_blocks = prefix_len // block_size
    new_seq.block_table[:prefix_blocks] = source_seq.block_table[:prefix_blocks]
    for block_id in source_seq.block_table[:prefix_blocks]:
        ref_counts[block_id] += 1

# Only copy a block when it is about to be modified
def copy_on_write(seq, block_idx):
    old_block = seq.block_table[block_idx]
    if ref_counts[old_block] > 1:
        new_block = allocate_block()
        copy(physical_blocks[new_block], physical_blocks[old_block])
        ref_counts[old_block] -= 1
        seq.block_table[block_idx] = new_block
        ref_counts[new_block] = 1
```
Memory Efficiency Comparison
| Approach | Utilization | Max Batch (A100 80GB) |
|---|---|---|
| Pre-allocated | ~20% | 8 sequences |
| Paged (vLLM) | ~90% | 40+ sequences |
Implementation in PyTorch
Simple KV Cache
```python
import torch

class KVCache:
    def __init__(self, max_batch_size, max_seq_len, num_layers, num_heads, head_dim):
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers
        # Pre-allocate the full cache up front
        self.k_cache = torch.zeros(
            num_layers, max_batch_size, max_seq_len, num_heads, head_dim
        )
        self.v_cache = torch.zeros(
            num_layers, max_batch_size, max_seq_len, num_heads, head_dim
        )
        self.seq_len = 0

    def update(self, layer_idx, k, v):
        """Write new K, V at the current position and return the cache so far."""
        batch_size, new_seq_len = k.shape[0], k.shape[1]
        start, end = self.seq_len, self.seq_len + new_seq_len
        self.k_cache[layer_idx, :batch_size, start:end] = k
        self.v_cache[layer_idx, :batch_size, start:end] = v
        # Advance the shared position counter only after the last layer,
        # so every layer in a step writes at the same offset
        if layer_idx == self.num_layers - 1:
            self.seq_len = end
        return (
            self.k_cache[layer_idx, :batch_size, :end],
            self.v_cache[layer_idx, :batch_size, :end],
        )

    def get(self, layer_idx, batch_size):
        """Retrieve all committed cached K, V for a layer."""
        return (
            self.k_cache[layer_idx, :batch_size, :self.seq_len],
            self.v_cache[layer_idx, :batch_size, :self.seq_len],
        )

    def reset(self):
        self.seq_len = 0
```
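A sketch of where update() sits inside a decode step; `project_kv` and `attend` are hypothetical layer helpers used only to show the data flow, not a real model API:

```python
cache = KVCache(max_batch_size=1, max_seq_len=2048,
                num_layers=num_layers, num_heads=num_heads, head_dim=head_dim)

def decode_step(model, x, cache):
    # x: hidden states for the newly generated token(s), shape [batch, new_len, d_model]
    for layer_idx, layer in enumerate(model.layers):
        k, v = layer.project_kv(x)                    # hypothetical helper
        k_all, v_all = cache.update(layer_idx, k, v)  # cached + new K, V for this layer
        x = layer.attend(x, k_all, v_all)             # hypothetical helper
    return x
```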
Using with Hugging Face
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("Hello, world!", return_tensors="pt")

# generate() first runs a prefill pass over the full prompt, then reuses the
# KV cache for every subsequent decode step (caching is on by default)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    use_cache=True,  # explicit here, but this is the default
    return_dict_in_generate=True,
)

# Access the past key values. Depending on the transformers version this is a
# Cache object or the legacy format: one (key, value) pair per layer, each of
# shape [batch, num_heads, seq_len, head_dim]
past_key_values = outputs.past_key_values
```
Static KV Cache (PyTorch 2.0+)
```python
# Static cache for torch.compile compatibility
from transformers import StaticCache

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
)

# Create a static (fixed-shape) cache
cache = StaticCache(
    config=model.config,
    max_batch_size=4,
    max_cache_len=2048,
    dtype=torch.float16,
)

# Use with generate
outputs = model.generate(
    input_ids,
    past_key_values=cache,
    max_new_tokens=100,
)
```
Optimization Tips
1. Right-size Your Cache
```python
# Don't over-allocate: size the cache for the tokens you will actually use
prompt_lens = [len(tokenizer(p).input_ids) for p in prompts]
actual_max_len = max(prompt_lens) + max_new_tokens
cache = create_cache(max_seq_len=actual_max_len)  # not 32K if you only need 2K
```
2. Use Appropriate Precision
```python
# FP16 halves KV cache memory relative to FP32
cache = cache.half()

# INT8 KV cache (experimental): some frameworks support a quantized KV cache
```
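A minimal sketch of what a quantized KV cache entry could look like, using simple per-tensor symmetric INT8 quantization; this is an illustration, not any particular framework's implementation:

```python
import torch

def quantize_int8(t):
    # Symmetric per-tensor quantization: store int8 values plus one FP scale
    scale = t.abs().max().clamp(min=1e-8) / 127.0
    return (t / scale).round().clamp(-127, 127).to(torch.int8), scale

def dequantize_int8(q, scale):
    return q.to(torch.float16) * scale

k = torch.randn(1, 32, 128)
k_q, k_scale = quantize_int8(k)
k_restored = dequantize_int8(k_q, k_scale)  # ~half the memory of FP16, small error
```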
3. Sliding Window Attention
For very long sequences, use sliding window (Mistral-style):
```python
# Only cache the last window_size tokens
window_size = 4096
if seq_len > window_size:
    k_cache = k_cache[:, -window_size:]
    v_cache = v_cache[:, -window_size:]
```
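Slicing like this reallocates the cache each step; an alternative sketch is a fixed-size rolling buffer that overwrites the slot of the token falling out of the window (single layer and batch assumed; a real implementation also has to account for the wrapped ordering when applying positions):

```python
num_heads, head_dim = 32, 128  # example dimensions
window_size = 4096
k_buffer = torch.zeros(window_size, num_heads, head_dim)
v_buffer = torch.zeros(window_size, num_heads, head_dim)

def append_to_window(position, k, v):
    # The token at 'position' reuses the slot vacated by position - window_size
    slot = position % window_size
    k_buffer[slot] = k
    v_buffer[slot] = v
```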
4. Speculative Decoding Cache Management
```python
# When using speculative decoding, manage the cache for draft + verify
# (schematic pseudocode: generate/verify/truncate stand in for real APIs)
def speculative_step(draft_model, target_model, cache, k=4):
    # Propose k candidate tokens with the draft model's own cache
    draft_tokens, draft_cache = draft_model.generate(k)
    # Verify the drafted tokens with the target model in a single pass
    accepted, target_cache = target_model.verify(draft_tokens, draft_cache)
    # Roll the cache back to the last accepted position
    cache.truncate(len(accepted))
    return accepted
```
References
- Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS 2017.
- Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." arXiv:1911.02150.
- Ainslie, J., et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." arXiv:2305.13245.
- Kwon, W., et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023.
- "PagedEviction: Structured Block-wise KV Cache Pruning." (2025). arXiv:2509.04377.