IQ Lab
← all posts
AI 2026.04.28 · 10 min read Advanced

TRPO는 왜 KL을 step size로 쓰는가

단조 개선 보장을 실전에서 구현하기 위한 TRPO의 constraint 형식화부터 Natural PG 환원, Conjugate Gradient, Line Search까지 — 하나의 철학이 네 단계로 펼쳐지는 과정을 추적한다.


정책 경사 알고리즘이 불안정한 근본 이유는 step size다. 너무 크면 정책이 무너지고, 너무 작으면 학습이 멈춘다. TRPO는 이 문제를 “step size를 KL divergence로 측정하라”는 단 하나의 아이디어로 풀어낸다. 그렇다면 왜 Euclidean 거리가 아닌 KL인가, 그리고 이 제약을 어떻게 실제로 계산하는가?

왜 KL이 자연스러운 step size인가

신경망 파라미터 θ\theta의 Euclidean 거리 Δθ2\|\Delta\theta\|_2는 parameterization에 종속된다. 마지막 레이어의 weight를 2배로 늘리면 정책 분포가 크게 변하지만 Euclidean distance는 작을 수 있다. 반대로 첫 번째 레이어의 미세한 변화가 출력 분포에는 거의 영향을 주지 않을 수도 있다.

KL divergence DˉKL(πθoldπθ)\bar{D}_{\text{KL}}(\pi_{\theta_{\text{old}}} \| \pi_\theta)는 이 문제에서 자유롭다. 정책 분포 자체의 변화를 측정하므로 parameterization에 독립적이다. 이것이 TRPO의 constraint 형식화다:

θnew=argmaxθLθold(θ)s.t.DˉKL(θoldθ)δ\theta_{\text{new}} = \arg\max_\theta L_{\theta_{\text{old}}}(\theta) \quad \text{s.t.} \quad \bar{D}_{\text{KL}}(\theta_{\text{old}} \| \theta) \leq \delta

여기서 Lθold(θ)L_{\theta_{\text{old}}}(\theta)는 importance sampling 가중치로 표현한 surrogate objective다.

Lθold(θ)=Es,aπold ⁣[πθ(as)πθold(as)Aπold(s,a)]L_{\theta_{\text{old}}}(\theta) = \mathbb{E}_{s, a \sim \pi_{\text{old}}}\!\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{\text{old}}}(a \mid s)} A^{\pi_{\text{old}}}(s, a)\right]

penalty 형식 maxLCDKL\max L - C \cdot D_{\text{KL}}과 비교하면, constraint 형식은 δ\delta가 “매 step에서 허용되는 KL 예산”이라는 직관적 의미를 갖는다. penalty의 CCLL과 KL의 스케일 비율에 따라 달라지므로 설정이 까다롭다.

트레이드오프

Constraint 형식은 δ\delta가 hyperparameter로서 직관적이고 매 step의 KL이 예측 가능하다는 장점이 있다. 단점은 표준 SGD로 직접 풀 수 없어 Lagrangian dual과 second-order 계산이 필요하다는 것이다. PPO는 이 비용을 피하기 위해 clipping으로 trust region을 근사한다.

Natural Policy Gradient로의 환원

TRPO를 직접 풀기 위해 Δθ=θθold\Delta\theta = \theta - \theta_{\text{old}}가 작다고 가정하고 Taylor 전개한다. KL의 0차·1차 항은 0이므로:

DˉKL(θoldθ)12ΔθF(θold)Δθ\bar{D}_{\text{KL}}(\theta_{\text{old}} \| \theta) \approx \frac{1}{2} \Delta\theta^\top F(\theta_{\text{old}}) \Delta\theta

여기서 FF는 Fisher information matrix다:

F(θ)=Es,aπθ[θlogπθ(as)  θlogπθ(as)]F(\theta) = \mathbb{E}_{s, a \sim \pi_\theta}[\nabla_\theta \log \pi_\theta(a \mid s) \; \nabla_\theta \log \pi_\theta(a \mid s)^\top]
정리 1 · TRPO Closed-form Update

Quadratic approximation 하에서 TRPO의 최적 업데이트는 다음과 같다:

Δθ=2δgF1gF1g\Delta\theta^* = \sqrt{\frac{2\delta}{g^\top F^{-1} g}} \cdot F^{-1} g

여기서 g=θLθold(θ)θoldg = \nabla_\theta L_{\theta_{\text{old}}}(\theta)\big|_{\theta_{\text{old}}}.

▷ 증명

Lagrangian L=gΔθλ(12ΔθFΔθδ)\mathcal{L} = g^\top \Delta\theta - \lambda(\frac{1}{2}\Delta\theta^\top F \Delta\theta - \delta)의 정류 조건에서 Δθ=λ1F1g\Delta\theta = \lambda^{-1} F^{-1} g. KL constraint가 binding될 때 12λ2gF1g=δ\frac{1}{2\lambda^2} g^\top F^{-1} g = \delta이므로 λ=gF1g/2δ\lambda = \sqrt{g^\top F^{-1} g / 2\delta}. 대입하면 closed-form이 나온다. \square

이 결과의 핵심은 방향 F1gF^{-1}g가 Kakade(2002)의 Natural Policy Gradient와 동일하다는 것이다. TRPO는 Natural PG에 “KL 예산으로부터 자동으로 결정되는 step size”를 더한 알고리즘이다. RL의 두 흐름인 trust region 방법과 natural gradient가 같은 수식으로 통합된다.

F1gF^{-1}g를 어떻게 계산하는가

문제는 FRd×dF \in \mathbb{R}^{d \times d}의 명시적 계산이 불가능하다는 것이다. d=106d = 10^6인 소형 신경망에서도 FF를 저장하는 데 1TB가 필요하고, 역행렬 계산은 O(d3)O(d^3)이다.

**Conjugate Gradient(CG)**가 이를 우회한다. CG는 Fx=gFx = gFF의 명시적 표현 없이 **FF와 벡터의 곱 FvFv**만으로 반복적으로 풀 수 있다. 그리고 FvFv 자체는 Pearlmutter trick으로 autograd만으로 계산된다:

Fv=θ ⁣[(θDˉKL)v]θoldFv = \nabla_\theta\!\left[(\nabla_\theta \bar{D}_{\text{KL}})^\top v\right]\bigg|_{\theta_{\text{old}}}

두 번의 backward pass로 FvFv를 얻는다. CG 10회 반복이면 충분한 이유는 신경망의 Fisher matrix가 사실상 낮은 effective rank를 가지기 때문이다 — 대부분의 고유값이 작고 소수만 지배적이므로 CG가 이 주요 방향들을 빠르게 수렴한다.

def fisher_vec_prod(pi, pi_old, v, damping=0.1):
    kl = mean_kl(pi_old, pi)
    grads = torch.autograd.grad(kl, list(pi.parameters()), create_graph=True)
    flat_grad = torch.cat([g.flatten() for g in grads])
    inner = (flat_grad * v).sum()
    Hv = torch.autograd.grad(inner, list(pi.parameters()))
    return torch.cat([h.flatten() for h in Hv]) + damping * v

damping * v 항은 F+λIF + \lambda I 정규화로, singular Fisher에 대한 수치적 안정성을 제공한다.

Line Search — 근사가 깨지는 곳을 막아라

Closed-form update는 quadratic KL 근사에 기반한다. Δθ\Delta\theta가 크면 실제 KL이 근사값을 벗어날 수 있다. 그래서 TRPO는 후보 step Δθ0\Delta\theta_0를 받아 backtracking으로 두 조건을 동시에 만족하는 가장 큰 step을 찾는다:

  1. DˉKL(πoldπnew)δ\bar{D}_{\text{KL}}(\pi_{\text{old}} \| \pi_{\text{new}}) \leq \delta
  2. Lθold(θnew)>Lθold(θold)L_{\theta_{\text{old}}}(\theta_{\text{new}}) > L_{\theta_{\text{old}}}(\theta_{\text{old}})
for j = 0, 1, ..., 9:
    Δθ_j = (0.5)^j · Δθ_0
    if mean_KL(π_old, π_new) ≤ δ AND L(π_new) > L(π_old):
        return θ_old + Δθ_j
return θ_old   # fallback

정리

  • KL divergence는 parameterization에 독립적인 정책 거리 측도다. Euclidean 거리로는 같은 정책 변화가 서로 다른 크기로 측정된다.
  • TRPO의 최적 update 방향은 Natural Policy Gradient F1gF^{-1}g와 동일하다. step size만 KL 예산 δ\delta로부터 자동 결정된다.
  • F1gF^{-1}g의 직접 계산은 불가능하다. CG + Pearlmutter trick이 O(d)O(d) 메모리와 O(dNiter)O(d \cdot N_{\text{iter}}) 시간으로 이를 근사한다.
  • Line search는 quadratic 근사가 무너지는 큰 step을 막는 마지막 안전망이다.

TRPO가 PPO에 의해 실전에서 대체된 것은 이 네 단계의 연산 비용 때문이다. 그러나 “KL을 예산으로 쓰는 trust region”이라는 설계 철학은 PPO의 clipping, SAC의 soft policy update, CPO의 안전 제약까지 이어진다. TRPO를 이해하는 것은 현대 policy optimization의 공통 언어를 배우는 것이다.

REF
Schulman et al. · 2015 · Trust Region Policy Optimization · ICML
REF
Kakade, S. · 2002 · A Natural Policy Gradient · NeurIPS