Multi-Query Attention (MQA)
VRAM: 8-32x smaller KV-cache
TL;DR
Single KV head shared across all query heads. Reduces KV-cache by 8-32x. Used in Falcon, PaLM.
Use when
- +Long context inference
- +Memory-constrained deployment
Skip when
- -Using pre-existing MHA model (architecture is fixed at training)
Multi-Query Attention uses a single key-value head shared across all query heads. This dramatically reduces KV-cache memory and improves inference speed.
How It Works
In standard multi-head attention (MHA), each head has its own Q, K, and V projections. In MQA, each head keeps its own Q projection, but all heads share a single K and a single V projection.
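A minimal numpy sketch of the shared-KV idea (illustrative only: no masking, no batching, and the function name and shapes are assumptions, not a real library API). Each of the `n_heads` query heads attends against the same single K and V via broadcasting:

```python
import numpy as np

def mqa_attention(q, k, v):
    """Multi-Query Attention sketch: one set of queries per head,
    but K and V are shared by every head (no masking, no batch dim)."""
    # q: (n_heads, seq_len, d_head); k, v: (seq_len, d_head) -- shared
    d_head = q.shape[-1]
    # Broadcast the single K across all query heads -> (n_heads, seq, seq)
    scores = q @ k.T / np.sqrt(d_head)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v  # (n_heads, seq_len, d_head)

rng = np.random.default_rng(0)
n_heads, seq, d = 4, 6, 8
q = rng.standard_normal((n_heads, seq, d))
k = rng.standard_normal((seq, d))  # single shared key head
v = rng.standard_normal((seq, d))  # single shared value head
out = mqa_attention(q, k, v)
print(out.shape)  # (4, 6, 8)
```

Note that only K and V collapse to one head; the output still has one slice per query head, so the rest of the transformer block is unchanged.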
Architecture Impact
For a model with 32 attention heads:

- **MHA**: 32 KV heads → 32x cache
- **MQA**: 1 KV head → 1x cache
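The cache arithmetic can be made concrete. A back-of-the-envelope sketch (the model dimensions below are hypothetical, chosen only to match the 32-head example above):

```python
def kv_cache_bytes(n_kv_heads, seq_len, d_head, n_layers, bytes_per_elem=2):
    """Per-sequence KV-cache size: 2 tensors (K and V) per layer,
    each of shape (seq_len, n_kv_heads, d_head), stored in fp16 by default."""
    return 2 * n_layers * seq_len * n_kv_heads * d_head * bytes_per_elem

# Hypothetical model: 32 layers, d_head=128, 4096-token context, fp16
mha = kv_cache_bytes(n_kv_heads=32, seq_len=4096, d_head=128, n_layers=32)
mqa = kv_cache_bytes(n_kv_heads=1,  seq_len=4096, d_head=128, n_layers=32)
print(mha // 2**30, "GiB vs", mqa // 2**20, "MiB")  # 2 GiB vs 64 MiB
print(mha // mqa)  # 32 -- the cache shrinks by exactly the KV-head count
```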
Key Benefits
- **Memory**: 8-32x smaller KV-cache
- **Speed**: faster attention due to the smaller cache
- **Tradeoff**: slight quality decrease