RNN 학습은 왜 이렇게 설계되었는가
Unrolled graph에서 BPTT 유도, truncated BPTT의 메모리 절약, 복잡도 분석, RTRL까지 — RNN 학습 알고리즘의 통일된 설계 철학을 추적한다.
RNN은 본질적으로 cyclic 구조다 — 가 자기 자신을 참조한다. 그런데 학습은 acyclic DAG 위의 chain rule을 요구한다. 이 모순을 어떻게 해결하는가? 그리고 그 해결책이 메모리 한계, 병렬성 부재, online 학습 불가라는 RNN의 세 가지 근본 제약을 어떻게 만들어내는가?
Unrolling — cyclic을 acyclic으로
해결책은 단순하다. 시간축으로 펼쳐서 DAG로 만든다.
Cyclic (RNN cell): Unrolled (DAG):
┌──────────┐ x₁→[cell]→h₁→[cell]→h₂→[cell]→h₃
x → [cell] →h │ │ │
└──┘ y₁ y₂ y₃
self-loop (모든 cell이 같은 W_hh, W_xh 공유)
Unrolling 후 각 는 DAG의 distinct 노드가 된다. 가 명시적 edge다. 이제 chain rule이 잘 정의된다.
핵심은 shared weight다. 모든 cell이 동일한 를 사용하므로, backward 시 gradient는 각 time step의 기여를 합산한다.
PyTorch autograd가 동일 variable의 multiple usage를 자동 합산하는 이유가 여기 있다. Unrolled graph의 topology가 합산을 강제한다.
BPTT — 유도하면 보이는 것
BPTT(Backpropagation Through Time)는 unrolled graph 위의 표준 backprop이다. 이름만 다를 뿐이다.
delta(backward signal) 를 정의하면 recursive form이 나온다.
는 두 경로로 loss에 기여한다 — 즉각적인 , 그리고 을 통한 모든 미래 loss. 그 합이 다.
weight gradient는 이 delta의 합산이다.
의 에 대한 gradient는 시간을 거슬러 올라가는 모든 path의 합이다.
의 모든 path에 chain rule을 적용하면, 각 path가 Jacobian 곱 를 생성한다. Shared weight이므로 모든 에 대해 합산한다.
이 Jacobian 곱 가 vanishing/exploding gradient의 핵심 항이다. spectral radius 이면 long-range gradient가 exponential decay한다.
Truncated BPTT — 편향을 감수하는 이유
Full BPTT는 정확하지만 두 가지 문제가 있다. 메모리가 이고, sequence가 끝나야 update할 수 있다.
Truncated BPTT()는 단순하다. 매 step마다 backward를 실행하고, hidden state를 detach()로 끊고, 다음 chunk를 forward한다.
h = h.detach() # gradient flow 차단 — 이전 chunk는 무시
logits, h = model(x_chunk, h)
loss.backward()
메모리는 로 줄고, 매 step마다 update가 발생한다. 대가는 bias다 — step 이전의 long-range gradient를 무시한다.
Bias의 크기는 — spectral radius 이면 에 대해 exponentially 감소한다. Karpathy의 char-RNN이 를 선택한 이유가 여기 있다. Plain RNN에서 이면 25 step 이후의 gradient는 어차피 작다. LSTM에서는 effective 가 1에 가까워지므로 더 큰 가 의미를 갖는다.
| Memory | Bias | Typical use | |
|---|---|---|---|
| 25 | 작음 | 작음 | Char-RNN (Karpathy) |
| 35 | 중간 | 매우 작음 | PyTorch LM tutorial |
| (full) | 0 | 짧은 sequence |
복잡도 — 병렬성이 없는 구조
BPTT의 시간 복잡도는 , 메모리는 다. 이 수치 자체보다 중요한 것은 sequential dependency다.
t=0 → t=1 → t=2 → ... → t=T
│ │ │ │
한 step이 끝나야 다음 step 시작 가능
batch 차원은 GPU에서 병렬화된다. sequence 차원은 본질적으로 sequential이다. GPU의 수천 개 core가 다음 step을 기다리며 대기한다.
Gradient checkpointing(Chen 2016)은 메모리와 시간을 tradeoff한다. 개의 checkpoint만 보존하고, backward 시 각 segment를 forward 재실행한다.
메모리는 배 절약, 시간은 약 1.5배 증가다. 이면 메모리 100배 절약이 가능하다.
RTRL — forward로 같은 gradient를
RTRL(Williams & Zipser 1989)은 BPTT와 동일한 gradient를 forward-mode AD로 계산한다.
핵심은 sensitivity matrix 를 forward와 함께 propagate하는 것이다.
매 step 에서 를 즉시 계산해 update한다. sequence가 끝날 때까지 기다리지 않는다.
대가는 의 per-step 복잡도다. 이고 이므로, 의 update에 이 필요하다. BPTT의 와 비교하면 배 더 느리다.
UORO(Tallec & Ollivier 2017)는 이를 rank-1 random projection으로 근사한다. 으로 표현하면 per-step 복잡도가 로 줄고, expectation은 unbiased로 유지된다.
정리
다섯 챕터의 핵심을 한 문장으로 압축하면: RNN 학습의 모든 설계 결정은 “cyclic 구조를 acyclic으로 펼치는 대가”를 어떻게 치를 것인가의 선택이다.
- Unrolling은 DAG를 만들어 chain rule을 가능하게 하지만, 메모리를 요구한다.
- Shared weight의 gradient 합산은 BPTT를 standard backprop으로 환원시키지만, Jacobian 곱의 누적이 vanishing/exploding을 만든다.
- Truncated BPTT는 메모리를 로 줄이지만, long-range gradient를 희생한다. Bias는 로 exponentially 감소한다.
- Sequential dependency는 GPU 병렬성을 막는다. 이것이 Transformer의 동기다.
- RTRL은 online update를 가능하게 하지만 비용을 치른다. UORO가 이를 로 근사한다.
다음 글에서는 Jacobian 곱 의 spectral 분석으로 vanishing/exploding gradient의 정확한 조건을 추적한다.