# Flash Attention 2

**Speed:** 2-4x vs standard attention

## TL;DR

Memory-efficient attention with IO-aware tiling. 2-4x faster than standard attention, with O(n) memory instead of O(n²).
## Use when

- Any transformer inference
- Long context lengths
- Using modern frameworks
## Skip when

- CPU-only inference
- Pre-Ampere GPUs (Flash Attention 2 requires Ampere or newer)
Flash Attention 2 is an IO-aware exact attention algorithm that reduces memory reads/writes by tiling the computation. It is the de facto standard for efficient transformer inference.
## How It Works
Instead of materializing the full attention matrix (O(n²) memory), Flash Attention computes attention in tiles that fit in SRAM, writing only the final output to HBM.
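The core trick is "online softmax": process the keys and values one tile at a time, keeping only a running max, a running denominator, and an unnormalized output accumulator, so the full score row is never stored. The tiling can be sketched in plain Python for a single query vector (an illustration only; function names and the tile size are made up here, and real Flash Attention runs this per query block in GPU SRAM):

```python
import math

def naive_attention(q, K, V):
    # Reference: materialize all scores, softmax once, weight the values.
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = [0.0] * len(V[0])
    for w, v in zip(exps, V):
        for d in range(len(v)):
            out[d] += (w / denom) * v[d]
    return out

def tiled_attention(q, K, V, tile=2):
    # Online softmax over key/value tiles: never holds all scores at once.
    m = float("-inf")          # running max of scores seen so far
    l = 0.0                    # running softmax denominator
    acc = [0.0] * len(V[0])    # unnormalized output accumulator
    for start in range(0, len(K), tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in Kt]
        m_new = max(m, max(scores))
        # Rescale previous state to the new max; exp(-inf) == 0.0
        # handles the very first tile.
        scale = math.exp(m - m_new)
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, Vt):
            w = math.exp(s - m_new)
            l += w
            for d in range(len(v)):
                acc[d] += w * v[d]
        m = m_new
    return [a / l for a in acc]
```

Because the rescaling by `exp(m - m_new)` keeps the running state consistent with the global max, the tiled result is exactly equal (up to floating point) to the one-shot softmax, which is why Flash Attention is exact rather than approximate.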
## Key Benefits

- **Memory**: O(n) instead of O(n²) for attention
- **Speed**: 2-4x faster than standard attention
- **Exact**: No approximation, identical outputs
## Code Examples
### Enable Flash Attention in Transformers

```python
from transformers import AutoModelForCausalLM

# Requires the flash-attn package and an Ampere-or-newer GPU
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype="float16",
)
```
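Passing `attn_implementation="flash_attention_2"` depends on the separately installed `flash-attn` package; a typical install looks like the following (exact flags may vary with your CUDA and compiler setup):

```shell
pip install flash-attn --no-build-isolation
```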