설명 가능 AI(Explainable AI, XAI)는 AI 모델의 예측 과정과 결과를 인간이 이해할 수 있도록 설명하는 기술과 방법론이다. 의료, 금융, 법률 등 고위험 분야에서의 AI 도입에 필수적이다.
설명 유형
전역적(Global) 설명: 모델 전체 동작 이해
└── 특성 중요도, 부분 의존성 그래프(PDP)
지역적(Local) 설명: 개별 예측 이유 설명
└── LIME, SHAP, Attention 시각화
SHAP (SHapley Additive exPlanations)
python
import shap
import xgboost as xgb
from sklearn.datasets import load_boston
import numpy as np
# 모델 학습
X, y = load_boston(return_X_y=True)
model = xgb.XGBRegressor(n_estimators=100, max_depth=4)
model.fit(X, y)
# SHAP 값 계산
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
# 특성 중요도 요약
shap.summary_plot(shap_values, X, show=False)
# 개별 예측 설명 (force plot)
shap.force_plot(explainer.expected_value, shap_values[0], X[0])
# SHAP 값 해석: 양수 = 예측값 증가 기여, 음수 = 감소 기여
print("특성별 평균 |SHAP|:", np.abs(shap_values).mean(axis=0))
LIME (Local Interpretable Model-Agnostic Explanations)
python
from lime import lime_tabular
explainer = lime_tabular.LimeTabularExplainer(
X_train,
feature_names=feature_names,
class_names=['낮음', '높음'],
mode='classification'
)
# 개별 샘플 설명
exp = explainer.explain_instance(
X_test[0],
model.predict_proba,
num_features=10
)
exp.show_in_notebook()
주요 XAI 기법 비교
| 기법 | 유형 | 적용 대상 | 특징 |
|---|
| SHAP | 지역/전역 | 트리, 딥러닝 | 수학적 보장, 느림 |
| LIME | 지역 | 모델 불가지론 | 빠름, 불안정성 |
| Grad-CAM | 지역 | CNN | 시각적 히트맵 |
| Attention | 지역 | 트랜스포머 | 내장, 해석 논란 |
| PDP | 전역 | 모델 불가지론 | 특성 간 상호작용 무시 |
Grad-CAM
python
import torch
import torch.nn.functional as F
def grad_cam(model, image, target_layer, class_idx):
"""CNN의 마지막 합성곱 레이어 기울기로 중요 영역 시각화"""
activations = []
gradients = []
def fwd_hook(module, inp, out): activations.append(out)
def bwd_hook(module, gin, gout): gradients.append(gout[0])
h1 = target_layer.register_forward_hook(fwd_hook)
h2 = target_layer.register_backward_hook(bwd_hook)
output = model(image)
model.zero_grad()
output[0, class_idx].backward()
h1.remove(); h2.remove()
act = activations[0].squeeze() # (C, H, W)
grad = gradients[0].squeeze() # (C, H, W)
weights = grad.mean(dim=(1, 2)) # 채널별 평균 기울기
cam = (weights[:, None, None] * act).sum(0)
cam = F.relu(cam)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
return cam.detach().cpu().numpy()
관련 개념