Inference Optimization

Running a 70B-parameter model in FP16 takes about 140GB of VRAM for the weights alone (2 bytes per parameter), so you need at least two A100 80GB GPUs (~$30k of hardware) before even accounting for the KV cache and activations. This is prohibitively expensive for most applications.

In this chapter, we will explore techniques to compress models and speed up inference without significantly sacrificing quality.

1. Quantization: Doing More with Less

Quantization reduces the precision of model weights from 16-bit floating point (FP16/BF16) to lower bit-widths like 8-bit or 4-bit integers.
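To make the idea concrete, here is a minimal sketch of symmetric "absmax" quantization. It assumes a single scale for the whole tensor (real schemes usually keep one scale per channel or per small group of weights), and the helper names are illustrative. Each value is divided by the scale and rounded to the nearest integer in the signed range; at use time the integer is multiplied back by the scale. Fewer bits means a coarser grid and a larger rounding error.

```python
def quantize(values, bits):
    """Symmetric absmax quantization with one scale for the whole tensor."""
    qmax = 2 ** (bits - 1) - 1                  # 127 for int8, 7 for int4
    scale = max(abs(v) for v in values) / qmax  # largest |value| maps to qmax
    codes = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return codes, scale

def dequantize(codes, scale):
    """Map integer codes back to approximate real values."""
    return [c * scale for c in codes]

weights = [0.12, -0.53, 0.91, -0.07, 0.33, -0.88, 0.05, 0.64]  # made-up weights
for bits in (8, 4):
    codes, scale = quantize(weights, bits)
    restored = dequantize(codes, scale)
    max_err = max(abs(w - r) for w, r in zip(weights, restored))
    print(f"int{bits}: first codes {codes[:4]}, max abs error {max_err:.4f}")
```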

Precision Comparison

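  • FP16/BF16 (Half Precision): 16 bits per weight. The usual baseline for training (in mixed precision) and for inference.
  • Int8: 8 bits. 2x memory reduction versus FP16. Accuracy loss is usually negligible.
  • Int4: 4 bits. About 4x memory reduction versus FP16 (slightly less in practice, because the quantization scales must also be stored). Slight accuracy loss, but a 70B model's weights shrink to roughly 35GB, which fits on a single 80GB GPU.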
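The 140GB figure and the single-GPU claim come straight from parameters × bits per weight. The sketch below counts weights only (no KV cache, activations, or quantization scales) and assumes 80GB GPUs; the helper names are illustrative.

```python
import math

def weight_memory_gb(n_params, bits_per_weight):
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_weight / 8 / 1e9

def gpus_needed(memory_gb, gpu_gb=80):
    """Minimum number of GPUs whose combined VRAM holds the weights."""
    return math.ceil(memory_gb / gpu_gb)

n_params = 70e9  # a 70B-parameter model
for name, bits in [("FP16", 16), ("Int8", 8), ("Int4", 4)]:
    mem = weight_memory_gb(n_params, bits)
    print(f"{name}: {mem:.0f} GB of weights -> at least {gpus_needed(mem)} x 80 GB GPU(s)")
```

Int8 weights (70GB) technically fit on one 80GB GPU as well, but leave only about 10GB for the KV cache and activations, whereas 4-bit leaves about 45GB.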

Methods

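  1. Post-Training Quantization (PTQ): Quantize an already-trained model, often with a small calibration dataset and no retraining. E.g., GPTQ, AWQ.
  2. Quantization-Aware Training (QAT): Simulate quantization during training or fine-tuning so the model learns to compensate, recovering accuracy at the cost of extra training compute.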

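[!TIP] AWQ (Activation-aware Weight Quantization) is one of the most widely used methods for 4-bit serving. It identifies the most “important” weight channels (those that multiply large-magnitude activations) and scales them up before quantization, applying the inverse scale to the activations, so those channels lose less precision while all weights stay in 4-bit.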
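The scaling trick can be illustrated with a toy single-output layer. This is a hand-picked example, not the AWQ algorithm itself: real AWQ searches per-channel scales on calibration data and quantizes weights in groups. Here channel 1 multiplies a large activation. Plain 4-bit rounding zeroes its small weight, while scaling that weight by 2 (and dividing its activation by 2, which leaves the full-precision output unchanged) keeps it on a nonzero quantization level.

```python
def quantize_group(weights, bits=4):
    """Symmetric absmax quantize-then-dequantize of one weight group."""
    qmax = 2 ** (bits - 1) - 1          # 7 for 4-bit
    scale = max(abs(w) for w in weights) / qmax
    return [max(-qmax, min(qmax, round(w / scale))) * scale for w in weights]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Toy layer: channel 1 sees a large activation, so it is "salient".
weights = [0.9, 0.05, -0.6, 0.3]
acts = [1.0, 50.0, 1.0, 1.0]
y_ref = dot(weights, acts)

# Plain 4-bit quantization: the small salient weight rounds to zero.
y_plain = dot(quantize_group(weights), acts)

# AWQ-style scaling: multiply the salient weight by s, divide its activation by s.
s = [1.0, 2.0, 1.0, 1.0]
scaled_w = [w * si for w, si in zip(weights, s)]
scaled_x = [x / si for x, si in zip(acts, s)]
y_scaled = dot(quantize_group(scaled_w), scaled_x)

err_plain = abs(y_plain - y_ref)
err_scaled = abs(y_scaled - y_ref)
print(f"plain 4-bit error:  {err_plain:.3f}")
print(f"scaled 4-bit error: {err_scaled:.3f}")
```

Scaling only helps while the group's largest weight (which sets the quantization step) is unchanged, which is why the scale has to be searched rather than made arbitrarily large.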

Code: Loading a Quantized Model

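Using bitsandbytes to load a model in 4-bit precision. Unlike GPTQ or AWQ, bitsandbytes quantizes the weights on the fly at load time and needs no calibration data.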

# pip install transformers bitsandbytes accelerate
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Define 4-bit configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
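    bnb_4bit_quant_type="nf4",              # NormalFloat4: 4-bit type suited to normally distributed weights
    bnb_4bit_compute_dtype=torch.float16,   # weights are dequantized to FP16 for the actual matmuls
    bnb_4bit_use_double_quant=True,         # also quantize the per-block scales to save extra memory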
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=bnb_config,
    device_map="auto"
)

print(f"Model footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
# Output: roughly 4-4.5 GB (vs ~14.5 GB for FP16)

2. Flash Attention

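Standard attention is O(N^2) in time and memory with respect to sequence length N, because it materializes an N x N score matrix for every head. For long contexts (e.g., RAG with 10k tokens), that matrix holds 100 million entries per head per layer.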
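A quick back-of-the-envelope check of how large those matrices get, assuming 32 attention heads (typical of 7B-class models), FP16 scores, and a single sequence; the helper name is illustrative:

```python
def attention_scores_bytes(seq_len, n_heads, bytes_per_elem=2, batch=1):
    """Size of the materialized N x N score matrices for one attention layer."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem

# Assumed: 32 heads, FP16 scores (2 bytes each), one 10k-token sequence.
size = attention_scores_bytes(seq_len=10_000, n_heads=32)
print(f"{size / 1e9:.1f} GB per layer")  # 6.4 GB per layer
```

Doubling the context to 20k tokens quadruples this figure.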

Flash Attention (Dao et al.) is an IO-aware exact attention algorithm. It minimizes memory reads/writes between the GPU’s fast on-chip SRAM and slow HBM (High Bandwidth Memory).

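  • Standard Attention: Writes the full N x N score matrix to HBM, then reads it back for the softmax and again for the multiplication with V.
  • Flash Attention: Tiles Q, K, and V into blocks that fit in SRAM and combines partial results with an online softmax, so the N x N matrix is never stored in HBM. Memory use becomes linear in N, and HBM access drops by up to 10x.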
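The trick that lets Flash Attention avoid storing the full matrix is the online softmax: scores are processed block by block while a running maximum and running sum are kept, and earlier partial results are rescaled whenever the maximum grows. The pure-Python sketch below shows this for a single query row. It illustrates the math only and has none of the GPU tiling or SRAM management of the real kernels; block_size is an arbitrary illustrative value.

```python
import math

def naive_attention(q, keys, values):
    """Reference: compute every score, softmax, then a weighted sum of values."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def tiled_attention(q, keys, values, block_size=2):
    """Online softmax over K/V blocks; only one block of scores exists at a time."""
    dim = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim                          # un-normalized output accumulator
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new maximum before adding this block.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [0.5, 0.5], [2.0, 0.0]]
print(naive_attention(q, keys, values))
print(tiled_attention(q, keys, values))  # matches up to floating-point rounding
```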

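[!NOTE] You usually don’t need to implement Flash Attention yourself. PyTorch 2.0+ exposes it through F.scaled_dot_product_attention (which selects a FlashAttention kernel when the hardware and inputs support it), and serving libraries like vLLM use it internally.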

3. KV Cache & Memory Management

As discussed in the Serving chapter, the KV Cache grows with every generated token. Efficiently managing this memory is key to throughput.
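The KV cache size follows directly from the model shape: one key and one value vector per layer, per KV head, per token. The sketch assumes a Llama-2-7B-like configuration (32 layers, 32 KV heads, head dimension 128) stored in FP16; the helper name is illustrative.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """Keys and values (hence the factor 2) for every layer, KV head, and token."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Assumed Llama-2-7B-like shape: 32 layers, 32 KV heads, head_dim 128, FP16.
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
full_ctx = kv_cache_bytes(32, 32, 128, seq_len=4096)
print(per_token / 2**20, "MiB per token")          # 0.5 MiB per token
print(full_ctx / 2**30, "GiB for 4096 tokens")     # 2.0 GiB for 4096 tokens
```

Models that use grouped-query attention keep fewer KV heads and therefore a proportionally smaller cache.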

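PagedAttention, used by vLLM, stores each sequence's KV cache in fixed-size blocks (pages) that need not be contiguous in physical VRAM. A per-sequence block table maps logical blocks to physical ones, much like virtual memory in an operating system, so memory is allocated on demand and fragmentation is limited to the last, partially filled block of each sequence.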
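The block-table idea can be sketched as a toy allocator. The class and its methods are hypothetical and only mirror the concept, not vLLM's actual implementation; block_size=16 is an assumed value. A sequence takes a block from the shared free pool only when its last block fills up, so its blocks can be scattered across physical memory, and a finished sequence returns its blocks immediately.

```python
class PagedKVAllocator:
    """Toy block manager: each sequence maps logical blocks to physical block ids."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))   # pool of physical block ids
        self.block_tables = {}                       # seq_id -> [physical ids]
        self.lengths = {}                            # seq_id -> tokens stored

    def append_token(self, seq_id):
        """Reserve space for one more token, grabbing a new block only when needed."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:            # no block yet, or last block full
            if not self.free_blocks:
                raise MemoryError("out of KV cache blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1

    def free(self, seq_id):
        """Return a finished sequence's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(20):
    alloc.append_token("a")      # "a" grows to 20 tokens (2 blocks)
for _ in range(10):
    alloc.append_token("b")      # "b" gets its first block in between
for _ in range(20):
    alloc.append_token("a")      # "a" grows to 40 tokens (3 non-adjacent blocks)
print(alloc.block_tables)
alloc.free("a")
print(len(alloc.free_blocks), "blocks free after 'a' finishes")
```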

4. Speculative Decoding

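Speculative Decoding leverages the fact that large models are slow but small models are fast, and that a large model can score several tokens in one forward pass for about the cost of generating one.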

Concept:

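  1. Use a small Draft Model that shares the target's tokenizer (e.g., Llama-7B) to quickly generate k candidate tokens (e.g., k = 5).
  2. Use the large Target Model (e.g., Llama-70B) to verify all k tokens in a single forward pass (parallel).
  3. Accept each draft token with probability min(1, p/q), where p and q are the target and draft probabilities of that token. At the first rejection, discard the remaining drafts and sample a replacement from the normalized residual distribution max(0, p - q). If all k drafts are accepted, the target's output at the last position provides one extra token.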
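The verification in step 3 is a rejection-sampling rule over the target and draft probabilities, following Leviathan et al. (2023) and Chen et al. (2023). The sketch below implements one verification round, assuming the distributions are already available as hypothetical token→probability dicts; in a real system the target produces all k + 1 distributions from its single forward pass.

```python
import random

def verify_draft(draft_tokens, q_dists, p_dists, rng):
    """One round of speculative sampling verification.

    draft_tokens[i]: token the draft model sampled at position i.
    q_dists[i] / p_dists[i]: draft / target distributions (token -> prob) at position i.
    p_dists has one extra entry: the target's distribution after the last draft token.
    Returns the tokens emitted this round (between 1 and k + 1 of them).
    """
    out = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_dists[i], q_dists[i]
        # Accept with probability min(1, p(tok) / q(tok)).
        if rng.random() < min(1.0, p.get(tok, 0.0) / q[tok]):
            out.append(tok)
            continue
        # First rejection: resample from max(0, p - q); choices() normalizes the weights.
        residual = {t: max(0.0, pt - q.get(t, 0.0)) for t, pt in p.items()}
        out.append(rng.choices(list(residual), weights=list(residual.values()))[0])
        return out                                  # remaining drafts are discarded
    # Every draft accepted: sample one bonus token from the target's last distribution.
    bonus = p_dists[len(draft_tokens)]
    out.append(rng.choices(list(bonus), weights=list(bonus.values()))[0])
    return out

# When draft and target agree exactly, every draft is accepted plus one bonus token.
rng = random.Random(0)
same = {"a": 0.5, "b": 0.5}
print(verify_draft(["a", "b", "a"], [same] * 3, [same] * 4, rng))
```

Because rejected probability mass is redistributed through the residual, each emitted token is distributed exactly as if it had been sampled from the target model.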

Because acceptance uses a rejection-sampling rule, the final output distribution is exactly the target model's. If the draft model is accurate enough, this can speed up inference by 2-3x.

The Trade-off

  • Good Draft Model: High acceptance rate → High speedup.
  • Bad Draft Model: Low acceptance rate → Can be slower than standard inference, because most of the draft model's work is discarded.
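The trade-off can be quantified with the analysis from Leviathan et al. (2023). The sketch assumes each draft token is accepted independently with probability alpha, k drafts per round, one target forward pass per round, and a draft model whose per-token cost is a fraction c of a target pass; the helper names are illustrative.

```python
def expected_tokens_per_pass(alpha, k):
    """Expected tokens emitted per target forward pass with k draft tokens.

    Assumes each draft token is accepted independently with probability alpha.
    """
    if alpha == 1.0:
        return k + 1
    return (1 - alpha ** (k + 1)) / (1 - alpha)

def expected_speedup(alpha, k, c):
    """Walltime speedup over plain decoding.

    c is the draft model's per-token cost relative to one target forward pass;
    each round costs k draft steps plus one target pass.
    """
    return expected_tokens_per_pass(alpha, k) / (k * c + 1)

print(f"good draft: {expected_speedup(alpha=0.8, k=5, c=0.05):.2f}x")  # 2.95x
print(f"bad draft:  {expected_speedup(alpha=0.2, k=5, c=0.2):.2f}x")   # 0.62x
```

With a well-matched draft the estimate is just under 3x; with a poor one it drops below 1x, i.e. slower than plain decoding.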

5. Summary

To run models in production efficiently:

  1. Quantize: Use AWQ (4-bit) to cut weight memory by roughly 4x versus FP16.
  2. Use Flash Attention: Reduce memory bandwidth bottlenecks.
  3. Manage KV Cache: Use PagedAttention (via vLLM) to maximize batch size.
  4. Speculative Decoding: If latency is critical, use a draft model.

Next, we will look at the critical layer of Safety and Moderation to prevent your optimized model from causing harm.