배치 정규화(Batch Normalization, BN)는 각 미니배치에서 레이어 입력을 정규화해 학습을 안정화하고 가속하는 기법이다. 2015년 Ioffe와 Szegedy가 제안했으며, 높은 학습률 사용, 가중치 초기화 민감도 감소, 정규화 효과를 제공한다.
수식
미니배치 $\mathcal{B} = {x_1, ..., x_m}$에 대해:
$$\hat{x}i = \frac{x_i - \mu\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}}$$
$$y_i = \gamma \hat{x}_i + \beta$$
- •$\mu_\mathcal{B}$: 미니배치 평균
- •$\sigma^2_\mathcal{B}$: 미니배치 분산
- •$\gamma, \beta$: 학습 가능한 스케일/시프트 파라미터
구현
python
import torch
import torch.nn as nn
class BatchNorm1D(nn.Module):
"""교육용 배치 정규화 구현"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.eps = eps
self.momentum = momentum
self.gamma = nn.Parameter(torch.ones(num_features))
self.beta = nn.Parameter(torch.zeros(num_features))
# 추론용 이동 평균
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
def forward(self, x):
if self.training:
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
# 이동 평균 업데이트
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
else:
mean = self.running_mean
var = self.running_var
x_hat = (x - mean) / (var + self.eps).sqrt()
return self.gamma * x_hat + self.beta
# PyTorch 내장 사용
bn = nn.BatchNorm2d(64) # 채널 수
x = torch.randn(16, 64, 28, 28)
print(bn(x).shape) # (16, 64, 28, 28)
정규화 기법 비교
| 기법 | 정규화 축 | 적합 상황 |
|---|
| Batch Norm | 배치 방향 | CNN, 큰 배치 |
| Layer Norm | 특성 방향 | 트랜스포머, NLP |
| Instance Norm | 공간 방향 | 스타일 트랜스퍼 |
| Group Norm | 그룹 방향 | 소배치 |
| Weight Norm | 가중치 재파라미터화 | RNN |
학습/추론 차이
python
model.train() # 미니배치 통계 사용
model.eval() # 이동 평균 통계 사용
# → eval() 호출 필수!
관련 개념