근위 정책 최적화(Proximal Policy Optimization, PPO)는 OpenAI가 2017년 제안한 강화학습 알고리즘이다. TRPO(Trust Region Policy Optimization)의 복잡성을 줄이면서도 안정적인 정책 업데이트를 달성한다. ChatGPT의 RLHF 학습에 활용된 알고리즘으로 유명하다.
핵심 아이디어
정책 업데이트 시 새로운 정책이 이전 정책에서 너무 멀어지지 않도록 클리핑(clipping)으로 제한한다.
$$L^{CLIP}(\theta) = \hat{E}_t\left[\min\left(r_t(\theta)\hat{A}_t,; \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right)\right]$$
- •$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ (확률 비율)
- •$\hat{A}_t$: 이점 함수 추정값
- •$\epsilon$: 클리핑 범위 (보통 0.1~0.2)
구현
python
import torch
import torch.nn as nn
import numpy as np
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim, hidden=64):
super().__init__()
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden), nn.Tanh(),
nn.Linear(hidden, hidden), nn.Tanh()
)
self.actor = nn.Linear(hidden, action_dim)
self.critic = nn.Linear(hidden, 1)
def forward(self, x):
h = self.shared(x)
return self.actor(h), self.critic(h)
def get_action(self, state):
logits, value = self(state)
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()
return action.item(), dist.log_prob(action), value
def ppo_update(model, optimizer, states, actions, old_log_probs,
returns, advantages, clip_eps=0.2, n_epochs=4):
for _ in range(n_epochs):
logits, values = model(states)
dist = torch.distributions.Categorical(logits=logits)
new_log_probs = dist.log_prob(actions)
entropy = dist.entropy().mean()
# 정책 손실 (클리핑)
ratio = (new_log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
# 가치 함수 손실
value_loss = (returns - values.squeeze()).pow(2).mean()
# 엔트로피 보너스 (탐험 장려)
loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
PPO vs 다른 알고리즘
| 알고리즘 | 안정성 | 샘플 효율 | 구현 난이도 |
|---|
| REINFORCE | 낮음 | 낮음 | 쉬움 |
| A3C | 중간 | 중간 | 중간 |
| TRPO | 높음 | 중간 | 어려움 |
| PPO | 높음 | 중간 | 쉬움 |
| SAC | 높음 | 높음 | 중간 |
관련 개념