Multi-Query Attention (MQA)

VRAM: 8-32x smaller KV-cache

TL;DR

Single KV head shared across all query heads. Reduces KV-cache by 8-32x. Used in Falcon, PaLM.

Use when

  • +Long context inference
  • +Memory-constrained deployment

Skip when

  • -Using a pre-existing MHA model (the attention architecture is fixed at training time)

Multi-Query Attention uses a single key-value head shared across all query heads. This dramatically reduces KV-cache memory and improves inference speed.

How It Works

In standard multi-head attention (MHA), each head has its own Q, K, and V projections. In MQA, each head keeps its own Q projection, but all heads share a single K and a single V projection — so the KV-cache stores one head's worth of keys and values instead of one per head.
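A minimal NumPy sketch of the shared-KV idea (shapes and function names are illustrative, not from any particular codebase):

```python
import numpy as np

def mqa_attention(x, wq, wk, wv, n_heads):
    """Multi-Query Attention: per-head Q projections, one shared K/V pair.

    x:  (seq_len, d_model) input
    wq: (d_model, d_model) query projection, split across n_heads
    wk: (d_model, head_dim) single shared key projection
    wv: (d_model, head_dim) single shared value projection
    """
    seq_len, d_model = x.shape
    head_dim = d_model // n_heads

    q = (x @ wq).reshape(seq_len, n_heads, head_dim)  # one Q per head
    k = x @ wk                                        # shared K: (seq_len, head_dim)
    v = x @ wv                                        # shared V: (seq_len, head_dim)

    # Every query head attends against the same K and V.
    scores = np.einsum("qhd,kd->hqk", q, k) / np.sqrt(head_dim)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)    # softmax over keys
    out = np.einsum("hqk,kd->qhd", weights, v)        # (seq_len, n_heads, head_dim)
    return out.reshape(seq_len, d_model)
```

Note that `wk` and `wv` each map to a single `head_dim`, not `d_model` — that single K/V pair is what gets cached during generation.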

Architecture Impact

For a model with 32 attention heads:

- **MHA**: 32 KV heads → 32x cache
- **MQA**: 1 KV head → 1x cache
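The cache-size ratio follows directly from the KV head count. A quick back-of-the-envelope calculation, using hypothetical model dimensions (32 layers, `d_model=4096`, `head_dim=128`, fp16) chosen only for illustration:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, dtype_bytes=2):
    # K and V each store (batch, seq_len, n_kv_heads, head_dim) per layer,
    # hence the factor of 2.
    return 2 * n_layers * batch * seq_len * n_kv_heads * head_dim * dtype_bytes

# Hypothetical 32-head model at a 4096-token context, batch size 1:
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096, batch=1)
mqa = kv_cache_bytes(n_layers=32, n_kv_heads=1,  head_dim=128, seq_len=4096, batch=1)
print(f"MHA: {mha / 2**30:.1f} GiB, MQA: {mqa / 2**30:.3f} GiB, ratio: {mha // mqa}x")
# → MHA: 2.0 GiB, MQA: 0.062 GiB, ratio: 32x
```

The ratio is exactly the number of KV heads, so a 32-head model sees a 32x reduction; models with fewer heads see proportionally less, which is where the 8-32x range comes from.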

Key Benefits

- **Memory**: 8-32x smaller KV-cache
- **Speed**: Faster attention decoding due to the smaller cache
- **Tradeoff**: Slight quality decrease relative to MHA