1. 개념 요약
- 데이터에 대한 라벨을 사전에 구축하지 않아도, self supervised learning을 통해 Representation Learning을 수행
2. 기본 아이디어
여러 입력 쌍에 대해 유사도를 라벨로 판별 모델을 학습함. (데이터 포인트 간의 유사성과 차이를 학습)
즉, 학습된 표현 공간 상에서 비슷한 데이터는 가까이, 다른 데이터는 멀리 위치하도록 학습하는 방법.
- 이때, GT label y가 0(similar) 아니면 1(dissimilar)임. (only two types)
- GT가 similar인 경우와 dissimilar인 경우 각각에 적용되는 loss function의 종류가 서로 다름 (L_s, L_d)
- 학습을 통해, 모델은 각 데이터 포인트를 잘 구별할 수 있는 특성을 파악해내고, 데이터 표현을 향상시키는 데 도움을 줌
3. 이점:
- self-supervised인 만큼 비용이 적음 (단, 아예 비지도 학습인 생성 모델을 활용하는 경우도 구축 비용이 적음.
라벨링 데이터를 활용하는 판별 모델보다는 구축 비용이 저렴함.)
- 보다 general한 feature representation & unseen class에도 대응 가능
- 다른 task로 fine tuning할 때 모델 구조 수정 없이 이뤄질 수 있어서 간편함.
4. 기본 원리
1) Data Augmentation
- positive pair & negative pair 생성 ((늘 그렇듯^^) negative sampling이 주요 관건)
2) Generating Representation (=Feature Extraction)
- 주로 pre-trained backbone 활용 (ex. ResNet)
3) Projection Head (=Obtain Metric Embedding)
- 2에서 얻은 extracted feature embedding을 metric space로 project시켜줌
- 이를 위해 간단한 MLP가 사용됨
- 이 Head 자체를 통해 metric learning이 진행되는 건 아니고, 앞단의 embedding output의 차원을 축소하는 역할을 함.
- 해당 head의 output이 이후 metric learning을 위한 contrastive metric loss 계산의 대상이 됨
4) Contrastive Loss 계산
- Loss function에 여러 종류가 있는데, 그 중 InfoNCE, NT-Xent 등이 있음
- 아까 언급했듯, GT가 similar인 경우와 dissimilar인 경우 적용되는 loss function의 종류가 서로 다름
4. 원리 조금 더 구체적으로
Input sample X_i 각각에 대해 이하 과정을 진행
1) prior knowledge를 바탕으로, X_i와 비슷한 것들을 모두 S_{xi}에 넣음
2) S_{xi}에 해당하는 모든 X에 대해서는 Y_{ij}의 라벨을 0으로 지정, 해당하지 않을 경우 라벨을 1로 지정
→ 마치 labeled training set처럼, (X_i, X_j, Y_{ij})과 같은 pair로 만듦
→ 실제 train을 진행할 때, Y_{ij}=0인 경우에는 W가 D_w = D(ij)를 줄이도록 학습하고, vice versa
5. Application
5-1. SimCLR
loss function으로 NT-Xent를 사용함 (분자: L_s, 분모: L_d)
이 NT-Xent를 통해 contrastive learning이 이루어지는 것.
5-2. NCE
SIMCLR은 데이터 포인트 간 유사성을 비교하고 구조화된 임베딩을 학습하는 반면,
NCE는 비교적 단순한 분류 모델(binary classifier)을 활용하여 데이터 포인트를 분류하며 학습
특히, 확률분포 개념을 이용함으로써 task가 더 간단해짐.
- M개의 training examples가 있다고 할 때
→ True pairs: centor word + another word appeared in context (즉 함께 get along하는 두 단어)
→ Fake pairs: 랜덤하게 아무 두 단어로 구성된 pair
- "True pairs가 어떠한 확률분포로부터 샘플링되었나?"
: 그것을 p_m이라고 가정하고, 우리의 목표는 p_m을 예측하는 것.
- 반면, Fake pairs에 대한 확률분포 p_n을 구해둠. (이건 상수같은 느낌, 학습 대상이 아님)
- 모델이 학습해야 할 task
: True pairs와 Fake pairs가 섞여있는 와중에, 각 샘플들이 p_m에서 왔는지 p_n에서 왔는지 binary classification
→ p_m에서 true pairs가 나올 확률을 maximize
→ p_m에서 fake pairs가 나올 확률을 minimize
reference: https://velog.io/@yjkim0520/Contrastive-Learning
https://cedar.buffalo.edu/~srihari/CSE676/18.6%20Noise%20Contrastive%20Estimation.pdf
'AI > Data Science' 카테고리의 다른 글
[선형대수학] np.linalg.eigh 내림차순 정렬 (0) | 2023.10.05 |
---|---|
[sql] where vs having (0) | 2023.08.31 |
폴더 내 파일 리스트/경로 리스트 생성 (0) | 2023.08.13 |
Batch Normalization vs Layer Normalization (0) | 2023.08.12 |
[Tensorflow] Callback 콜백함수란? (0) | 2023.08.03 |