Inference Optimization
Running a 70B-parameter model requires ~140 GB of VRAM just for the weights in FP16, which means you need at least two A100 80GB GPUs (~$30k of hardware). This is prohibitively expensive for most applications.
In this chapter, we will explore techniques to compress models and speed up inference without significantly sacrificing quality.
1. Quantization: Doing More with Less
Quantization reduces the precision of model weights from 16-bit floating point (FP16/BF16) to lower bit-widths like 8-bit or 4-bit integers.
Precision Comparison
- FP16 (Half Precision): 16 bits per weight. The standard for training.
- Int8: 8 bits. 2x memory reduction. Negligible accuracy loss.
- Int4: 4 bits. 4x memory reduction. Slight accuracy loss, but allows running 70B models on a single GPU.
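The memory figures above are easy to sanity-check with back-of-the-envelope arithmetic (a sketch of weight memory only; real footprints also include quantization constants, activations, and the KV cache):

```python
def model_footprint_gb(n_params: float, bits_per_weight: int) -> float:
    """Approximate weight memory: parameters x bits per weight, in gigabytes."""
    return n_params * bits_per_weight / 8 / 1e9

for bits in (16, 8, 4):
    print(f"70B @ {bits}-bit: {model_footprint_gb(70e9, bits):.0f} GB")
# 70B @ 16-bit: 140 GB  -> needs two A100 80GB
# 70B @ 8-bit:   70 GB  -> fits on a single A100 80GB
# 70B @ 4-bit:   35 GB  -> fits on a single 40-48GB GPU
```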
Methods
- Post-Training Quantization (PTQ): Quantize after training. E.g., GPTQ, AWQ.
- Quantization-Aware Training (QAT): Retrain with quantization to recover accuracy.
[!TIP] AWQ (Activation-aware Weight Quantization) is currently among the strongest methods for 4-bit serving. It identifies the most "important" weight channels (those multiplied by large activations) and scales them before quantization to protect their accuracy.
Code: Loading a Quantized Model
Using bitsandbytes to load a model in 4-bit precision.
```python
# pip install transformers bitsandbytes accelerate
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Define 4-bit configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NormalFloat4 (optimized for normally distributed weights)
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=bnb_config,
    device_map="auto",
)

print(f"Model footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
# Output: ~4.5 GB (vs ~14 GB for FP16)
```
2. Flash Attention
Attention is O(N^2) with respect to sequence length. For long contexts (e.g., RAG with 10k tokens), the attention matrix becomes huge.
Flash Attention (Dao et al.) is an IO-aware exact attention algorithm. It minimizes memory reads/writes between the GPU’s fast on-chip SRAM and slow HBM (High Bandwidth Memory).
- Standard Attention: Reads N x N matrix from HBM multiple times.
- Flash Attention: Tiles the computation to keep data in SRAM, reducing HBM access by up to 10x.
[!NOTE] You usually don’t need to implement Flash Attention yourself. It’s integrated into PyTorch 2.0 (`F.scaled_dot_product_attention`) and libraries like vLLM.
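In PyTorch 2.x, `scaled_dot_product_attention` automatically dispatches to a Flash Attention kernel when the hardware and dtypes allow it (falling back to a memory-efficient or plain math backend otherwise). A minimal call, with illustrative tensor shapes:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 1024, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# PyTorch selects the fastest available backend; is_causal applies the decoder mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```

Note that the N x N attention matrix is never materialized in user code, which is exactly what makes the fused kernels IO-efficient.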
3. KV Cache & Memory Management
As discussed in the Serving chapter, the KV Cache grows with every generated token. Efficiently managing this memory is key to throughput.
*(Interactive widget: KV Cache Memory Manager — visualizes how PagedAttention maps logical sequences (requests) onto non-contiguous physical VRAM blocks (pages) for the KV cache, preventing fragmentation.)*
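To see why paging matters, it helps to compute how large the KV cache actually gets. A rough per-token formula (ignoring grouped-query attention, which shrinks it): 2 tensors (K and V) × layers × heads × head_dim × bytes per element. The config below is an illustrative 7B-class model, not any specific checkpoint:

```python
def kv_cache_bytes_per_token(layers: int, heads: int, head_dim: int,
                             dtype_bytes: int = 2) -> int:
    """K and V tensors, one pair per layer, stored for every token in context."""
    return 2 * layers * heads * head_dim * dtype_bytes

# Illustrative 7B-class config: 32 layers, 32 heads, head_dim 128, FP16.
per_token = kv_cache_bytes_per_token(32, 32, 128)
print(f"{per_token // 1024} KiB per token")                         # 512 KiB
print(f"{per_token * 4096 / 1e9:.1f} GB for a 4k-token sequence")   # ~2.1 GB
```

At half a megabyte per token, a handful of long requests can consume more VRAM than the quantized weights themselves, which is why vLLM allocates this memory in fixed-size pages rather than contiguous slabs.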
4. Speculative Decoding
Speculative Decoding leverages the fact that large models are slow but small models are fast.
Concept:
- Use a small Draft Model (e.g., Llama-7B) to quickly generate 5 candidate tokens.
- Use the large Target Model (e.g., Llama-70B) to verify all 5 tokens in a single forward pass (parallel).
- Accept the valid tokens and discard the rest.
If the draft model is accurate enough, this can speed up inference by 2-3x without changing the final output distribution.
The Trade-off
- Good Draft Model: High acceptance rate → High speedup.
- Bad Draft Model: Low acceptance rate → Slower than standard inference (overhead).
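The draft-then-verify loop can be sketched with toy stand-ins for the two models. Here both "models" are deterministic functions from a context to its next token, and verification is exact-match; real systems compare probability distributions via rejection sampling, but the accept-longest-prefix structure is the same:

```python
def speculative_step(draft_model, target_model, prefix, k=5):
    """One speculative decoding step: draft k tokens, keep the verified prefix."""
    # 1. Draft model proposes k tokens autoregressively (cheap).
    draft, ctx = [], list(prefix)
    for _ in range(k):
        t = draft_model(ctx)
        draft.append(t)
        ctx.append(t)

    # 2. Target model scores all k positions in one (parallelizable) pass.
    target = [target_model(list(prefix) + draft[:i]) for i in range(k)]

    # 3. Accept the longest matching prefix, plus one "free" correction token.
    accepted = []
    for d, t in zip(draft, target):
        if d != t:
            accepted.append(t)  # target's correction replaces the bad draft token
            break
        accepted.append(d)
    return accepted

# Toy models: target emits last token + 1; draft agrees except on multiples of 4.
target_model = lambda ctx: ctx[-1] + 1
draft_model = lambda ctx: ctx[-1] + 1 if (ctx[-1] + 1) % 4 else ctx[-1] + 2

print(speculative_step(draft_model, target_model, [0]))  # [1, 2, 3, 4]
```

Four tokens are produced for the cost of one target-model pass instead of four, and the output is exactly what the target model alone would have generated, which is the key guarantee of speculative decoding.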
5. Summary
To run models in production efficiently:
- Quantize: Use AWQ (4-bit) to reduce VRAM by 4x.
- Use Flash Attention: Reduce memory bandwidth bottlenecks.
- Manage KV Cache: Use PagedAttention (via vLLM) to maximize batch size.
- Speculative Decoding: If latency is critical, use a draft model.
Next, we will look at the critical layer of Safety and Moderation to prevent your optimized model from causing harm.