IQ Lab
← all posts
AI 2026.04.28 · 10 min read Advanced

RNN 학습은 왜 이렇게 설계되었는가

Unrolled graph에서 BPTT 유도, truncated BPTT의 메모리 절약, 복잡도 분석, RTRL까지 — RNN 학습 알고리즘의 통일된 설계 철학을 추적한다.


RNN은 본질적으로 cyclic 구조다 — ht=f(ht1,xt)h_t = f(h_{t-1}, x_t)가 자기 자신을 참조한다. 그런데 학습은 acyclic DAG 위의 chain rule을 요구한다. 이 모순을 어떻게 해결하는가? 그리고 그 해결책이 메모리 한계, 병렬성 부재, online 학습 불가라는 RNN의 세 가지 근본 제약을 어떻게 만들어내는가?

Unrolling — cyclic을 acyclic으로

해결책은 단순하다. 시간축으로 펼쳐서 DAG로 만든다.

Cyclic (RNN cell):          Unrolled (DAG):

   ┌──────────┐         x₁→[cell]→h₁→[cell]→h₂→[cell]→h₃
   x → [cell] →h                  │       │       │
       └──┘                       y₁      y₂      y₃
       self-loop         (모든 cell이 같은 W_hh, W_xh 공유)

Unrolling 후 각 hth_t는 DAG의 distinct 노드가 된다. ht=f(ht1,xt)h_t = f(h_{t-1}, x_t)가 명시적 edge다. 이제 chain rule이 잘 정의된다.

핵심은 shared weight다. 모든 cell이 동일한 WhhW_{hh}를 사용하므로, backward 시 gradient는 각 time step의 기여를 합산한다.

LWhh=t=1TLWhhstep t\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \left.\frac{\partial L}{\partial W_{hh}}\right|_{\text{step } t}

PyTorch autograd가 동일 variable의 multiple usage를 자동 합산하는 이유가 여기 있다. Unrolled graph의 topology가 합산을 강제한다.

BPTT — 유도하면 보이는 것

BPTT(Backpropagation Through Time)는 unrolled graph 위의 표준 backprop이다. 이름만 다를 뿐이다.

delta(backward signal) δt=L/ht\delta_t = \partial L / \partial h_t를 정의하면 recursive form이 나온다.

δt=Ltht+Jt+1δt+1,Jt=diag(tanh(zt))Whh\delta_t = \frac{\partial L_t}{\partial h_t} + J_{t+1}^\top \delta_{t+1}, \quad J_t = \mathrm{diag}(\tanh'(z_t))\, W_{hh}

hth_t는 두 경로로 loss에 기여한다 — 즉각적인 LtL_t, 그리고 ht+1h_{t+1}을 통한 모든 미래 loss. 그 합이 δt\delta_t다.

weight gradient는 이 delta의 합산이다.

LWhh=t=1T(δttanh(zt))ht1\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} (\delta_t \odot \tanh'(z_t))\, h_{t-1}^\top
명제 1 · Jacobian 곱 누적

LtL_tWhhW_{hh}에 대한 gradient는 시간을 거슬러 올라가는 모든 path의 합이다.

LtWhh=k=1t(j=k+1tJj)Lththk1tanh(zk)\frac{\partial L_t}{\partial W_{hh}} = \sum_{k=1}^{t} \left(\prod_{j=k+1}^{t} J_j^\top\right) \frac{\partial L_t}{\partial h_t}\, h_{k-1}^\top \tanh'(z_k)
▷ 증명

Lththt1hkWhhL_t \leftarrow h_t \leftarrow h_{t-1} \leftarrow \cdots \leftarrow h_k \leftarrow W_{hh}의 모든 path에 chain rule을 적용하면, 각 path가 Jacobian 곱 j=k+1tJj\prod_{j=k+1}^{t} J_j를 생성한다. Shared weight이므로 모든 kk에 대해 합산한다. \square

이 Jacobian 곱 Jj\prod J_j가 vanishing/exploding gradient의 핵심 항이다. spectral radius ρ(Jj)<1\rho(\prod J_j) < 1이면 long-range gradient가 exponential decay한다.

Truncated BPTT — 편향을 감수하는 이유

Full BPTT는 정확하지만 두 가지 문제가 있다. 메모리가 O(TH)O(TH)이고, sequence가 끝나야 update할 수 있다.

Truncated BPTT(kk)는 단순하다. 매 kk step마다 backward를 실행하고, hidden state를 detach()로 끊고, 다음 chunk를 forward한다.

h = h.detach()  # gradient flow 차단 — 이전 chunk는 무시

logits, h = model(x_chunk, h)
loss.backward()

메모리는 O(kH)O(kH)로 줄고, 매 kk step마다 update가 발생한다. 대가는 bias다 — kk step 이전의 long-range gradient를 무시한다.

트레이드오프

Bias의 크기는 O(ρk)O(\rho^k) — spectral radius ρ<1\rho < 1이면 kk에 대해 exponentially 감소한다. Karpathy의 char-RNN이 k=25k = 25를 선택한 이유가 여기 있다. Plain RNN에서 ρ<1\rho < 1이면 25 step 이후의 gradient는 어차피 작다. LSTM에서는 effective ρ\rho가 1에 가까워지므로 더 큰 kk가 의미를 갖는다.

kkMemoryBiasTypical use
25작음작음Char-RNN (Karpathy)
35중간매우 작음PyTorch LM tutorial
TT (full)O(TH)O(TH)0짧은 sequence

복잡도 — 병렬성이 없는 구조

BPTT의 시간 복잡도는 O(TH2)O(TH^2), 메모리는 O(TH)O(TH)다. 이 수치 자체보다 중요한 것은 sequential dependency다.

t=0 → t=1 → t=2 → ... → t=T
 │     │     │           │
 한 step이 끝나야 다음 step 시작 가능

batch 차원은 GPU에서 병렬화된다. sequence 차원은 본질적으로 sequential이다. GPU의 수천 개 core가 다음 step을 기다리며 대기한다.

Gradient checkpointing(Chen 2016)은 메모리와 시간을 tradeoff한다. T\sqrt{T}개의 checkpoint만 보존하고, backward 시 각 segment를 forward 재실행한다.

s=T    M=O(TH),Tcost=O(TH21.5)s^* = \sqrt{T} \implies M = O(\sqrt{T} H), \quad T_{\text{cost}} = O(TH^2 \cdot 1.5)

메모리는 T\sqrt{T}배 절약, 시간은 약 1.5배 증가다. T=10000T = 10000이면 메모리 100배 절약이 가능하다.

RTRL — forward로 같은 gradient를

RTRL(Williams & Zipser 1989)은 BPTT와 동일한 gradient를 forward-mode AD로 계산한다.

핵심은 sensitivity matrix St=ht/θS_t = \partial h_t / \partial \theta를 forward와 함께 propagate하는 것이다.

St=JtSt1+fθpartialS_t = J_t S_{t-1} + \frac{\partial f}{\partial \theta}\bigg|_{\text{partial}}

매 step tt에서 Lt/θ=(Lt/ht)St\partial L_t / \partial \theta = (\partial L_t / \partial h_t) \cdot S_t를 즉시 계산해 update한다. sequence가 끝날 때까지 기다리지 않는다.

대가는 O(H4)O(H^4)의 per-step 복잡도다. StRH×θS_t \in \mathbb{R}^{H \times |\theta|}이고 θ=O(H2)|\theta| = O(H^2)이므로, StS_t의 update에 O(H4)O(H^4)이 필요하다. BPTT의 O(H2)O(H^2)와 비교하면 H2H^2배 더 느리다.

UORO(Tallec & Ollivier 2017)는 이를 rank-1 random projection으로 근사한다. StutvtS_t \approx u_t v_t^\top으로 표현하면 per-step 복잡도가 O(H2)O(H^2)로 줄고, expectation은 unbiased로 유지된다.

정리

다섯 챕터의 핵심을 한 문장으로 압축하면: RNN 학습의 모든 설계 결정은 “cyclic 구조를 acyclic으로 펼치는 대가”를 어떻게 치를 것인가의 선택이다.

  • Unrolling은 DAG를 만들어 chain rule을 가능하게 하지만, O(TH)O(TH) 메모리를 요구한다.
  • Shared weight의 gradient 합산은 BPTT를 standard backprop으로 환원시키지만, Jacobian 곱의 누적이 vanishing/exploding을 만든다.
  • Truncated BPTT는 메모리를 O(kH)O(kH)로 줄이지만, long-range gradient를 희생한다. Bias는 O(ρk)O(\rho^k)로 exponentially 감소한다.
  • Sequential dependency는 GPU 병렬성을 막는다. 이것이 Transformer의 동기다.
  • RTRL은 online update를 가능하게 하지만 O(H4)O(H^4) 비용을 치른다. UORO가 이를 O(H2)O(H^2)로 근사한다.

다음 글에서는 Jacobian 곱 Jj\prod J_j의 spectral 분석으로 vanishing/exploding gradient의 정확한 조건을 추적한다.

REF
Tallec & Ollivier · 2017 · Training Recurrent Networks Online without Backtracking · arXiv
REF