멀티헤드 어텐션(Multi-Head Attention)은 Transformer의 핵심 메커니즘으로, 여러 어텐션 헤드가 병렬로 서로 다른 표현 공간에서 어텐션을 계산한다. 단일 어텐션보다 다양한 관계 패턴을 포착한다.
단일 어텐션 (Scaled Dot-Product Attention)
Attention(Q, K, V) = softmax(QK^T / √dk) V
Q(Query): "무엇을 찾고 있나?"
K(Key): "나는 이런 정보를 갖고 있다"
V(Value): "실제 정보 내용"
1/√dk: 내적 값이 커지면 softmax 포화 방지를 위한 스케일링
멀티헤드 어텐션
MultiHead(Q,K,V) = Concat(head_1, ..., head_h) W^O
각 헤드 i:
head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
h=8개 헤드가 병렬로 서로 다른 선형 투영 사용
직관적 이해
문장: "The cat sat on the mat"
헤드 1: 구문적 관계 포착 (cat ↔ sat)
헤드 2: 의미적 관계 포착 (sat ↔ mat, on ↔ mat)
헤드 3: 장거리 의존성 포착
헤드 4: 대명사 해소 포착
...
각 헤드가 다른 "측면"을 학습
python
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, n_heads=8):
super().__init__()
self.d_k = d_model // n_heads
self.n_heads = n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
B = Q.size(0)
# 선형 투영 + 헤드 분할
Q = self.W_q(Q).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
# 어텐션 계산
scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)
if mask is not None:
scores.masked_fill_(mask == 0, -1e9)
attn = scores.softmax(-1)
out = (attn @ V).transpose(1, 2).contiguous().view(B, -1, self.n_heads * self.d_k)
return self.W_o(out)
관련 개념
- •Transformer — 멀티헤드 어텐션이 핵심인 아키텍처
- •BERT — 양방향 멀티헤드 어텐션
- •GPT — 인과적(masked) 멀티헤드 어텐션
- •임베딩 — 어텐션 입력 Q, K, V의 기반