Attention의 O(T²) 벽을 어떻게 부수는가
Self-attention의 이차 복잡도가 만드는 메모리·시간 병목의 근원부터, Linear·Sparse·Flash·MQA/GQA 네 가지 해법의 수학적 원리와 트레이드오프까지 추적한다.
- 02 Transformer Block은 왜 이 네 요소의 조합인가
- 03 Positional Encoding은 어떻게 진화했나
- 04 Transformer 훈련을 가능하게 하는 다섯 가지 설계 결정
- 05 Attention의 O(T²) 벽을 어떻게 부수는가
- 06 BERT, GPT, T5, ViT, MoE — 다섯 아키텍처는 하나의 질문에 답한다
- 07 LLM은 왜 클수록 똑똑한가 — Scaling Laws의 세계
Transformer의 attention은 시간과 메모리를 요구한다. 이면 attention matrix 하나가 2GB를 넘고, 이면 32GB를 넘는다. 단일 GPU에 들어가지 않는다. 그렇다면 1M 토큰 컨텍스트를 처리하는 현대 LLM은 어떻게 이 벽을 넘었는가?
병목의 정체
Standard self-attention의 연산은 세 단계다.
각 행렬 곱이 FLOP을 소비하고, 그 결과물인 가 GPU 메모리에 상주한다. 헤드가 개이면 메모리는 으로 늘어난다.
Sequence length T: 2048 8192 32768 131072
Attention matrix: 16MB 256MB 4 GB 64 GB
(BF16, single head)
시간 복잡도도 같은 구조다. 가 4배 늘면 연산량은 16배 늘어난다. 이 이차 blow-up이 long context 모델의 근본 제약이다.
KV cache는 또 다른 병목이다. Autoregressive generation에서 매 스텝마다 이전 토큰의 K, V를 누적한다.
LLaMA-2 70B()의 32K 컨텍스트에서 KV cache만 80GB에 달한다. 모델 가중치와 동급이다.
Linear Attention — 결합 순서 하나로 O(T)로
Katharopoulos 2020이 제안한 아이디어는 단순하다. 행렬 곱의 결합 순서를 바꾼다.
는 행렬이고, 와 무관하다. Softmax가 있으면 이 분리가 불가능하므로, softmax를 양수 feature map 로 대체한다.
Katharopoulos는 을 제안했다. 이때 시간 복잡도는 , 메모리는 로 내려간다. 이면 이차 항이 사라진다.
Causal linear attention의 incremental computation은 고정 크기 상태 를 갖는 RNN과 동치다.
상태를 , 로 정의하면, 스텝 의 출력은 이다. 의 점화식이 성립하므로, 각 스텝의 계산이 인 RNN 형태가 된다.
대가는 명확하다. Softmax의 날카로운 피크(sharp peak)가 사라진다. 특정 토큰에 집중해야 하는 fine-grained retrieval에서 성능이 떨어진다. Performer(Choromanski 2021)는 이 손실을 random feature로 보완한다 — 를 선택하면 가 성립해 softmax kernel의 불편 추정이 된다. 그러나 분산이 에 반비례하므로 가 작으면 추정이 불안정하다.
Sparse Attention — 필요한 쌍만 계산한다
Longformer(Beltagy 2020)와 BigBird(Zaheer 2020)는 다른 방향을 택한다. 모든 쌍을 계산하는 대신, 미리 정의한 sparse pattern만 계산한다.
BigBird pattern:
Local window: |i - j| ≤ w
Global tokens: CLS, special tokens (양방향)
Random: 각 토큰이 랜덤 r개 토큰에 attend
Local이 근거리 의존성을 잡고, global이 허브 역할을 하며, random이 임의의 long-range 경로를 만든다. 이 random 연결이 핵심이다.
Local + Global + Random sparse pattern의 결합이 expander graph를 형성하면, 충분한 깊이의 BigBird는 임의의 연속 시퀀스-투-시퀀스 함수를 근사한다.
임의 토큰 쌍의 최단 경로가 이므로, 레이어로 dense attention과 동등한 정보 전파가 가능하다는 것이 증명의 핵심이다.
이론상 sparse 연산이 빠르지만, naive 구현은 GPU에서 dense보다 느릴 수 있다. Irregular memory access가 Tensor Core를 유휴 상태로 만들기 때문이다. 실제 효율은 block-sparse나 sliding window처럼 구조적 sparsity를 Triton/CUDA로 구현할 때만 얻을 수 있다.
Mistral 7B가 sliding window attention을 채택한 이유가 여기 있다. 완전히 random한 BigBird 패턴보다 block-friendly한 local window가 GPU에서 훨씬 효율적이다.
Flash Attention — 같은 FLOP, 다른 메모리 경로
Flash Attention(Dao 2022)은 복잡도를 바꾸지 않는다. 여전히 FLOP을 수행한다. 그런데 어떻게 2-4배 빠른가?
GPU의 메모리 계층을 직접 겨냥한다.
HBM (80 GB, ~600 cycles) ← Standard attention의 병목
SRAM (256 KB, ~10 cycles) ← Flash Attention의 작업 공간
Standard attention은 크기의 행렬을 HBM에 쓰고 읽는다. HBM 전송이 발생한다. Flash Attention은 Q, K, V를 SRAM에 맞는 블록으로 나누고, 블록 안에서 모든 연산을 완료한 뒤 출력만 HBM에 쓴다.
이를 위해 online softmax가 필요하다. 전체 행을 보지 않고 블록 단위로 분모를 점진적으로 갱신한다.
이 업데이트가 정확히 standard softmax와 동일한 결과를 내는 것이 증명 가능하다. 즉 Flash Attention은 근사가 아니다. 표현력 손실 없이 메모리를 에서 로 줄이고 속도를 2-4배 높인다.
HBM IO 복잡도는 으로, 이 SRAM 크기다. 블록이 클수록 HBM 접근이 줄어든다. 에서 attention matrix 메모리가 256MB에서 4MB로 내려간다.
MQA / GQA — KV Cache를 줄인다
Flash Attention이 학습의 게임 체인저라면, MQA/GQA는 추론의 게임 체인저다.
Multi-Query Attention(Shazeer 2019)은 개의 Q-head가 단 하나의 K, V를 공유한다. Grouped-Query Attention(Ainslie 2023)은 이를 일반화해 개의 KV 그룹을 둔다.
LLaMA-2 70B는 , (GQA-8)을 채택한다. KV cache가 8배 줄어든다. 32K 컨텍스트에서 43GB 대신 5.4GB.
왜 추론에서만 효과가 큰가? 학습 시에는 배치 단위로 compute-bound이지만, generation 시에는 매 스텝마다 전체 KV cache를 HBM에서 불러오는 memory-bandwidth bound 상황이 된다. HBM bandwidth가 1.5 TB/s인 A100에서 KV cache 로드 시간이 토큰 생성의 지배적 비용이다. GQA-8은 이 로드 시간을 8배 줄인다.
표현력 손실은 작다. Q-head는 여전히 독립적인 를 가지므로, 같은 KV 공간에서 다른 쿼리로 다른 attention 패턴을 형성한다. 실증적으로 70B 규모에서는 정확도 손실이 0.5% 이내다.
트레이드오프
네 가지 방법은 각자 다른 축에서 최적화한다.
| 방법 | 시간 | 메모리 | 정확성 | 핵심 제약 |
|---|---|---|---|---|
| Standard | Exact | 에 이차 | ||
| Linear | 근사 | Sharpness 손실 | ||
| Sparse | 제한적 | GPU 구현 어려움 | ||
| Flash | 유효 | Exact | GPU 전용 | |
| MQA/GQA | Exact | 표현력 소폭 손실 |
정리
- 병목은 시간과 메모리 두 축에서 동시에 발생한다. KV cache는 세 번째 병목이다.
- Linear Attention은 를 달성하지만 softmax의 selectivity를 잃는다.
- Flash Attention은 같은 FLOP으로 IO를 줄여 2-4배를 얻는다. 근사 없이.
- GQA는 KV head를 줄여 추론 throughput을 $