
LLM Inference Optimization: From Naive to Production-Ready

Complete guide to optimizing LLM inference for production. Covers KV caching, quantization, batching strategies, speculative decoding, and serving frameworks with benchmarks.

Flash Attention Team · January 8, 2026 · 11 min read
Tags: LLM inference, inference optimization, KV cache, quantization, vLLM, TensorRT-LLM

Deploying LLMs in production requires careful optimization to balance latency, throughput, and cost. This comprehensive guide covers every major technique for making LLM inference fast and efficient.

Understanding LLM Inference

The Autoregressive Bottleneck

LLMs generate text one token at a time:

# Pseudocode for autoregressive generation
def generate(prompt, max_tokens):
    tokens = tokenize(prompt)

    for _ in range(max_tokens):
        # Full forward pass for ONE token
        logits = model(tokens)
        next_token = sample(logits[-1])
        tokens.append(next_token)

        if next_token == EOS:
            break

    return tokens

This creates two distinct phases:

Phase | Compute Pattern | Bottleneck
Prefill | Process all prompt tokens at once | Compute-bound
Decode | Generate one token at a time | Memory-bound

Key Metrics

Metric | Definition | Target Range
Time to First Token (TTFT) | Latency until the first token is produced | 100-500 ms
Inter-Token Latency (ITL) | Time between successive tokens | 20-50 ms
Throughput | Tokens generated per second | 50-500+ tok/s
Tokens per Dollar | Cost efficiency | Maximize
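
A minimal sketch of how the latency metrics can be measured against any streaming endpoint, assuming a hypothetical stream_tokens generator that yields decoded tokens as they arrive:

import time

def measure_latency(stream_tokens, prompt):
    """Measure TTFT, average ITL, and throughput for one request.

    `stream_tokens(prompt)` is a placeholder for your server's streaming
    API (vLLM, TGI, etc.); it should yield one token at a time.
    """
    start = time.perf_counter()
    token_times = []

    for _ in stream_tokens(prompt):
        token_times.append(time.perf_counter())

    ttft = token_times[0] - start                              # time to first token
    gaps = [b - a for a, b in zip(token_times, token_times[1:])]
    avg_itl = sum(gaps) / len(gaps) if gaps else 0.0           # inter-token latency
    throughput = len(token_times) / (token_times[-1] - start)  # tokens per second
    return ttft, avg_itl, throughput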

KV Cache: The Foundation

How KV Cache Works

Without caching, each new token requires recomputing attention over all previous tokens:

# Without KV cache: every decode step recomputes the K/V projections
# for the entire prefix before attending, so one step is O(n²) work
# and generating n tokens is O(n³) overall
for i in range(seq_len):
    for j in range(i + 1):
        K[j], V[j] = project(hidden[j])   # recomputed at every step
        attention[i, j] = compute(Q[i], K[j], V[j])

With KV cache, we store and reuse Key and Value projections:

# With KV cache: O(n) per token, O(n²) total
kv_cache = []
for i in range(seq_len):
    k_i, v_i = project(hidden[i])
    kv_cache.append((k_i, v_i))

    # Attend to all cached keys/values
    attention[i] = compute(Q[i], kv_cache)
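
The same pattern in runnable form with Hugging Face Transformers, where the cache is exposed as past_key_values (a minimal greedy-decoding sketch; the model id is just an example):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)

# Prefill: one forward pass over the whole prompt fills the KV cache
out = model(**inputs, use_cache=True)
past_key_values = out.past_key_values
next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

generated = [next_token]
for _ in range(20):
    # Decode: feed only the newest token and reuse the cached K/V
    out = model(next_token, past_key_values=past_key_values, use_cache=True)
    past_key_values = out.past_key_values
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    generated.append(next_token)

print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))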

KV Cache Memory

KV Cache Size = 2 × num_layers × num_kv_heads × head_dim × seq_len × batch_size × bytes_per_element

For LLaMA-2 70B with a 4K context (80 layers, head_dim 128, FP16):

  • Full multi-head K/V (64 heads): 2 × 80 × 64 × 128 × 4096 × 2 bytes ≈ 10.7 GB per sequence
  • With GQA (8 KV heads, which LLaMA-2 70B actually uses): 2 × 80 × 8 × 128 × 4096 × 2 bytes ≈ 1.3 GB per sequence
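
A quick sanity check of those numbers in Python (a minimal sketch; the helper is ours, not a library function):

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_element=2):
    # The factor 2 accounts for storing both keys and values
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_element

# LLaMA-2 70B, 4K context, FP16
mha = kv_cache_bytes(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=4096)
print(f"MHA: {mha / 1e9:.1f} GB, GQA (8 KV heads): {gqa / 1e9:.1f} GB")
# MHA: 10.7 GB, GQA (8 KV heads): 1.3 GB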

KV Cache Optimizations

1. Multi-Query Attention (MQA)

Share K,V across attention heads:

# Standard: Each head has own K, V
# K, V shape: [batch, num_heads, seq, head_dim]

# MQA: Single K, V for all heads
# K, V shape: [batch, 1, seq, head_dim]
# Reduces KV cache by num_heads (e.g., 32x)

2. Grouped-Query Attention (GQA)

LLaMA-2 70B, Mistral, and most newer models use GQA, a middle ground:

# GQA: Groups of heads share K, V
# num_kv_heads = num_heads // group_size
# LLaMA-2 70B: 64 heads, 8 KV heads → 8x reduction
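
A hedged sketch of the idea (MQA is simply the special case with a single KV head): the smaller set of cached K/V heads is repeated to match the query heads before ordinary attention, here using PyTorch's scaled_dot_product_attention:

import torch
import torch.nn.functional as F

def gqa_attention(q, k, v):
    """q: [batch, num_heads, seq, head_dim]
    k, v: [batch, num_kv_heads, seq, head_dim] (the only tensors that go in the KV cache)"""
    group_size = q.shape[1] // k.shape[1]
    # Repeat each KV head so every query head has a matching K/V
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# LLaMA-2 70B style shapes: 64 query heads, 8 KV heads, head_dim 128
q = torch.randn(1, 64, 16, 128)
k = torch.randn(1, 8, 16, 128)
v = torch.randn(1, 8, 16, 128)
out = gqa_attention(q, k, v)   # -> [1, 64, 16, 128]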

3. Paged Attention (vLLM)

Manage KV cache like virtual memory:

# Traditional: Contiguous pre-allocated cache
cache = torch.zeros(max_seq_len, hidden_dim)  # Wasteful!

# Paged: Allocate blocks on demand
block_table = {}  # Maps logical → physical blocks
def get_block(seq_id, block_idx):
    if (seq_id, block_idx) not in block_table:
        block_table[(seq_id, block_idx)] = allocate_block()
    return block_table[(seq_id, block_idx)]

Quantization for Inference

Quantization Methods Comparison

Method | Bits | Speed | Quality | Use Case
FP16 | 16 | 1.0x | Best | Quality-critical
INT8 | 8 | 1.5x | Excellent | Balanced
INT4 (GPTQ) | 4 | 2.0x | Good | Memory-limited
INT4 (AWQ) | 4 | 2.0x | Better | Production
GGUF Q4 | 4 | 1.8x | Good | CPU inference

GPTQ Quantization

Post-training quantization using calibration data:

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)

gptq_config = GPTQConfig(
    bits=4,
    dataset="c4",         # calibration dataset
    tokenizer=tokenizer,  # used to tokenize the calibration samples
    group_size=128,
    desc_act=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=gptq_config,
    device_map="auto",
)

AWQ (Activation-aware Weight Quantization)

Preserves important weights based on activation patterns:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-hf"
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(
    tokenizer,
    quant_config={
        "w_bit": 4,
        "q_group_size": 128,
        "zero_point": True,
    },
)

Memory Savings from Quantization

Model | FP16 | INT8 | INT4
7B | 14 GB | 7 GB | 3.5 GB
13B | 26 GB | 13 GB | 6.5 GB
70B | 140 GB | 70 GB | 35 GB
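
These figures are simply parameter count × bits per weight (weights only, excluding activations and KV cache):

def weight_memory_gb(params_billion, bits):
    # bits / 8 = bytes per parameter; result in GB
    return params_billion * 1e9 * bits / 8 / 1e9

for size in (7, 13, 70):
    print(f"{size}B: FP16 {weight_memory_gb(size, 16):.1f} GB, "
          f"INT8 {weight_memory_gb(size, 8):.1f} GB, "
          f"INT4 {weight_memory_gb(size, 4):.1f} GB")
# 7B: FP16 14.0 GB, INT8 7.0 GB, INT4 3.5 GB  (and likewise for 13B and 70B)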

Batching Strategies

Static Batching

Simple but inefficient: the whole batch runs in lockstep until every sequence finishes:

# Static batching
def static_batch_generate(prompts, max_tokens):
    # Pad all to same length
    padded = pad_sequences(prompts)

    # Generate for fixed steps
    for _ in range(max_tokens):
        outputs = model(padded)
        # All sequences generate same number of tokens

Problem: Short sequences wait for long ones.

Continuous Batching

Add/remove sequences dynamically:

# Continuous batching (vLLM, TGI)
class ContinuousBatcher:
    def __init__(self, max_batch_size=32):
        self.max_batch_size = max_batch_size
        self.active_sequences = []
        self.waiting_queue = []

    def step(self):
        # Generate one token for all active sequences
        outputs = model.generate_step(self.active_sequences)

        # Remove finished sequences
        finished = [s for s in self.active_sequences if s.is_done()]
        self.active_sequences = [s for s in self.active_sequences if not s.is_done()]

        # Add new sequences from queue
        while self.waiting_queue and len(self.active_sequences) < self.max_batch_size:
            self.active_sequences.append(self.waiting_queue.pop(0))

        return finished

Throughput Comparison

Batching | Throughput | GPU Utilization
No batching | 30 tok/s | 5%
Static (batch=8) | 150 tok/s | 25%
Continuous (batch=32) | 800 tok/s | 70%
Continuous + PagedAttention | 1500 tok/s | 85%

Speculative Decoding

Use a small model to draft tokens, verify with large model:

def speculative_decode(prompt, draft_model, target_model, k=4):
    tokens = tokenize(prompt)

    while not done:
        # Draft: Generate k tokens with small model (fast)
        draft_tokens = draft_model.generate(tokens, num_tokens=k)

        # Verify: Check all k tokens in parallel with large model
        target_logits = target_model(tokens + draft_tokens)

        # Accept matching tokens, reject from first mismatch
        accepted = verify_and_accept(draft_tokens, target_logits)
        tokens.extend(accepted)

Speculative Decoding Speedup

Draft Model | Target Model | Acceptance Rate | Speedup
68M | 7B | 70% | 2.1x
160M | 7B | 80% | 2.5x
1B | 70B | 75% | 2.8x

Key insight: the target model verifies all k drafted tokens in a single forward pass, which in the memory-bound decode regime costs roughly the same as generating one token.
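
A hedged sketch of the verification step that the pseudocode above delegates to verify_and_accept (greedy variant for clarity; production implementations use a rejection-sampling rule that preserves the target model's sampling distribution):

def verify_and_accept(draft_tokens, target_logits):
    """Accept draft tokens while they match the target model's argmax,
    then substitute the target's own prediction at the first mismatch.

    draft_tokens:  list of k proposed token ids
    target_logits: [prompt_len + k, vocab_size] logits from one target forward pass
    """
    k = len(draft_tokens)
    # The logit row at position -(k + 1) + i predicts draft position i,
    # so the last k + 1 rows cover all drafts plus one bonus token
    target_preds = target_logits[-(k + 1):].argmax(dim=-1)

    accepted = []
    for i, draft in enumerate(draft_tokens):
        if draft == target_preds[i].item():
            accepted.append(draft)                    # target agrees with the draft
        else:
            accepted.append(target_preds[i].item())   # take the target's token and stop
            break
    else:
        # Every draft accepted: the target's final prediction comes for free
        accepted.append(target_preds[k].item())
    return accepted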

Serving Frameworks

vLLM

High-throughput serving with PagedAttention:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
    max_num_batched_tokens=8192,
)

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=256,
)

prompts = ["Explain KV caching in one sentence."]  # any list of prompt strings
outputs = llm.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)

TensorRT-LLM

NVIDIA's optimized inference engine:

# Build optimized engine
from tensorrt_llm import Builder

builder = Builder()
engine = builder.build(
    model_dir="llama-7b",
    dtype="float16",
    max_batch_size=32,
    max_input_len=2048,
    max_output_len=512,
)

# Run inference
outputs = engine.generate(
    input_ids,
    max_new_tokens=256,
    temperature=0.7,
)

Text Generation Inference (TGI)

Hugging Face's production server:

# Docker deployment
docker run --gpus all -p 8080:80 \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id meta-llama/Llama-2-7b-hf \
    --quantize bitsandbytes-nf4 \
    --max-batch-prefill-tokens 4096
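
Once the server is up, requests go to its REST API; a minimal sketch against the /generate endpoint on the port mapped above:

import requests

resp = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "Explain KV caching in one sentence.",
        "parameters": {"max_new_tokens": 256, "temperature": 0.7, "top_p": 0.9},
    },
    timeout=60,
)
print(resp.json()["generated_text"])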

Framework Comparison

Feature | vLLM | TensorRT-LLM | TGI
PagedAttention | Yes | Yes | Yes
Continuous Batching | Yes | Yes | Yes
Speculative Decoding | Yes | Yes | No
Multi-GPU | Yes | Yes | Yes
Quantization | AWQ, GPTQ | FP8, INT8, INT4 | bitsandbytes, GPTQ
Setup Complexity | Low | High | Medium

Flash Attention for Inference

Flash Attention speeds up both prefill and decode:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Automatic Flash Attention in forward pass
output = model.generate(input_ids, max_new_tokens=100)

Impact on Inference

Sequence Length | Standard Attention | Flash Attention | Speedup
512 | 15 ms | 8 ms | 1.9x
2048 | 89 ms | 32 ms | 2.8x
8192 | 1420 ms | 128 ms | 11x
32768 | OOM | 512 ms | N/A

torch.compile for Inference

PyTorch 2.0's compiler provides free speedups:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
)

# Compile for inference
model = torch.compile(model, mode="reduce-overhead")

# First call triggers compilation (slow)
# Subsequent calls are optimized
output = model.generate(input_ids, max_new_tokens=100)

Compile Modes

Mode | Compilation Time | Runtime Speed | Use Case
default | Medium | 1.3x | General
reduce-overhead | Longer | 1.5x | Latency-critical
max-autotune | Very long | 1.7x | Production deploy

Production Optimization Checklist

Memory Optimization

# 1. Use quantization
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# 2. Enable Flash Attention
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
)

# 3. Optimize KV cache
# Use GQA models (LLaMA-2 70B, Mistral, LLaMA 3)
# Enable paged attention (vLLM)

Latency Optimization

# 1. Compile model
model = torch.compile(model, mode="reduce-overhead")

# 2. Use CUDA graphs (for fixed shapes)
# Automatically enabled in vLLM, TensorRT-LLM

# 3. Speculative decoding for interactive use
# Draft model + target model verification
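
For the speculative-decoding item, Hugging Face Transformers exposes this as assisted generation through the assistant_model argument of generate(); a hedged sketch (the draft checkpoint is one commonly used example, and any small model sharing the target's tokenizer works):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

target_id = "meta-llama/Llama-2-7b-hf"
draft_id = "JackFram/llama-68m"   # assumed small draft model

tokenizer = AutoTokenizer.from_pretrained(target_id)
target = AutoModelForCausalLM.from_pretrained(
    target_id, torch_dtype=torch.float16, device_map="auto"
)
draft = AutoModelForCausalLM.from_pretrained(
    draft_id, torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("Speculative decoding works by", return_tensors="pt").to(target.device)

# The draft model proposes tokens; the target model verifies them in parallel
output = target.generate(**inputs, assistant_model=draft, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))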

Throughput Optimization

# 1. Continuous batching
# Use vLLM or TGI instead of naive batching

# 2. Maximize batch size
# Profile to find optimal batch size for your GPU

# 3. Tensor/pipeline parallelism for very large models
tensor_parallel_size = 4  # e.g., vLLM: shard each layer across 4 GPUs

Benchmarks

Single GPU Performance (A100 80GB)

LLaMA-2 7B, 2K input + 256 output tokens:

Configuration | TTFT | ITL | Throughput
Naive PyTorch | 850 ms | 45 ms | 22 tok/s
+ Flash Attention | 320 ms | 28 ms | 35 tok/s
+ torch.compile | 280 ms | 22 ms | 45 tok/s
+ INT4 Quantization | 180 ms | 15 ms | 65 tok/s
vLLM (batch=32) | 400 ms | 8 ms | 420 tok/s
TensorRT-LLM | 150 ms | 6 ms | 580 tok/s

Multi-GPU Scaling

LLaMA-2 70B throughput (tokens/second):

GPUs | Tensor Parallel | Pipeline Parallel | Combined
1 | OOM | N/A | N/A
2 | 85 tok/s | N/A | 85 tok/s
4 | 180 tok/s | 160 tok/s | 220 tok/s
8 | 320 tok/s | 280 tok/s | 420 tok/s

Cost Optimization

Tokens per Dollar (approximate, cloud pricing)

Setup | Cost/hour | Throughput | Tokens/$
A100 40GB (vLLM) | $3.50 | 400 tok/s | 411K
A100 80GB (vLLM) | $5.00 | 600 tok/s | 432K
4x A10G (TGI) | $5.60 | 800 tok/s | 514K
H100 (TensorRT-LLM) | $8.00 | 1500 tok/s | 675K
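
The last column is simply sustained throughput × 3600 seconds divided by the hourly price:

def tokens_per_dollar(throughput_tok_s, cost_per_hour):
    return throughput_tok_s * 3600 / cost_per_hour

print(f"{tokens_per_dollar(400, 3.50):,.0f}")   # ≈ 411,429 (A100 40GB row)
print(f"{tokens_per_dollar(1500, 8.00):,.0f}")  # 675,000 (H100 row)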

Right-sizing Recommendations

Use Case | Model Size | Hardware | Framework
Chatbot (low latency) | 7B | A10G | vLLM
Batch processing | 7-13B | A100 | TensorRT-LLM
High quality | 70B | 8x A100 | vLLM + tensor parallelism
Cost-sensitive | 7B INT4 | T4 | TGI

References

  1. Kwon, W., et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023

  2. vLLM Comparative Analysis. (2025). "Comparative Analysis of Large Language Model Inference Serving Systems." arXiv:2511.17593

  3. PagedAttention + FlexAttention. (2025). "Paged Attention Meets FlexAttention: Unlocking Long-Context Efficiency." arXiv:2506.07311

  4. Frantar, E., et al. (2023). "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers." ICLR 2023

  5. Lin, J., et al. (2023). "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration." arXiv:2306.00978

  6. NVIDIA. (2025). "TensorRT-LLM." GitHub
