Vision Transformer(ViT)는 이미지를 고정 크기 패치(patch)로 분할해 트랜스포머에 입력하는 모델이다. 2020년 Google Brain의 Dosovitskiy 등이 제안했으며, 충분한 데이터로 학습 시 CNN을 능가하는 성능을 보인다.
핵심 구조
이미지 (H×W×C)
│
├── 패치 분할: (H/P × W/P)개의 P×P 패치
│
├── 선형 임베딩: 각 패치 → D차원 벡터
│
├── [CLS] 토큰 추가 + 위치 임베딩
│
└── Transformer Encoder (L개 레이어)
├── Multi-Head Self-Attention
└── MLP
출력: [CLS] 토큰 → 분류 헤드
구현
python
import torch
import torch.nn as nn
import math
class PatchEmbedding(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.n_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x) # (B, embed_dim, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, n_patches, embed_dim)
return x
class ViT(nn.Module):
def __init__(self, img_size=224, patch_size=16, num_classes=1000,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0):
super().__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
n_patches = self.patch_embed.n_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=num_heads,
dim_feedforward=int(embed_dim * mlp_ratio),
batch_first=True, norm_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x) # (B, N, D)
cls = self.cls_token.expand(B, -1, -1) # (B, 1, D)
x = torch.cat([cls, x], dim=1) # (B, N+1, D)
x = x + self.pos_embed
x = self.transformer(x)
x = self.norm(x[:, 0]) # CLS 토큰
return self.head(x)
model = ViT()
print("파라미터 수:", sum(p.numel() for p in model.parameters()) // 1e6, "M")
ViT 변형 비교
| 모델 | 패치 | 레이어 | 헤드 | 임베딩 | 파라미터 |
|---|
| ViT-S/16 | 16 | 12 | 6 | 384 | 22M |
| ViT-B/16 | 16 | 12 | 12 | 768 | 86M |
| ViT-L/16 | 16 | 24 | 16 | 1024 | 307M |
| ViT-H/14 | 14 | 32 | 16 | 1280 | 632M |
CNN vs ViT
| 항목 | CNN | ViT |
|---|
| 귀납 편향 | 강함 (지역성, 등변성) | 약함 |
| 소량 데이터 | 강점 | 약점 |
| 대량 데이터 | 좋음 | 더 좋음 |
| 계산 복잡도 | O(n) | O(n²) (어텐션) |
관련 개념