FixMatch: 적은 label에도 성능 올리기
🖼️

FixMatch: 적은 label에도 성능 올리기

Tags
Computer Vision
Machine Learning
Published
December 9, 2022
Author
유레미 EUREMI
현재 회사에선 label의 신뢰성이 낮은 데이터를 다루고 있습니다. 이를 수작업으로 거를 수는 없어 사람의 리소스가 적으면서도 성능을 올리기 위한 방법에 대해 고민했고 그 중 발견한 논문입니다. NeurIPS 2020에서 Google Research가 발표한 FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence에 대해 정리한 글입니다. 고려대학교 산업경영학부 DSBA 연구실 Lab Seminar 유튜브 영상을 참고하였습니다.
 

Introduction

deep network는 supervised learning에 좋은 성능을 보입니다. 방대한 양의 데이터셋을 사용하여 성능을 올리기엔 labeling이 필요하고 이를 위해선 사람의 노동력이 필요합니다. 또한, 전문가(예를 들면, 의료 분야의 의사)가 labeling을 수행하는 경우 극심한 비용이 들 수 있습니다. 이를 해결하기 위해 semi-supervised learning(SSL) 기법을 많이 연구하고 있습니다. SSL은 label이 없는 데이터를 학습하는 접근법입니다. 일반적으로 label이 없는 이미지에 pseudo label을 생성하고 label이 없는 이미지를 pseudo label로 예측하도록 모델을 훈련시키는 기법을 사용합니다.
 

Model Structure

notion image
label이 주어진 데이터로 학습하고 label이 없는 데이터로 검증하여 예측한 값을 pseudo-labeling이라 합니다. 아래의 그림으로 이해해보겠습니다. 1. label이 있는 데이터를 이용해 모델을 학습합니다. 2. 학습된 모델로 unlabeled data를 예측합니다. 이를 통해 예측된 값이 pseudo label이 되는 것이죠. 3. 마지막으로 label이 없으나 pseudo label이 생긴 데이터와 label이 있는 데이터를 합쳐 재학습하게 됩니다.
notion image
fixmatch는 consistency regularization과 pseudo-labeling을 이용하여 pseudo label을 만들어냅니다. 그 중에서도 weakly-augmented unlabeled image를 이용하여 strong-augmented unlabeled image의 pseudo label을 만들어냅니다. fixmatch를 CIFAR-10을 이용하여 250개의 label이 있는 데이터를 가지고서도 SOTA를 찍어냈습니다.
 

Pseudo-labeling

labeling된 데이터를 이용해 모델을 학습한 뒤 학습한 모델로 labeling되어있지 않은 unlabeled 데이터를 예측합니다. 예측된 값들 중 confidence score가 높은 데이터들만 pseudo label을 부여합니다. 예측된 label은 실제 label이 아니기 때문에 pseudo label이라 합니다. labeled 데이터와 unlabeled 데이터에 pseudo label을 붙여서 학습한 뒤 또 다시 labeling 되어있지 않은 나머지 unlabeled 데이터를 학습합니다. 계속해서 반복하여 unlabeled data를 줄여나갑니다.
notion image
 

Consistency regularization

데이터에 작은 변형(weakly augmentation)을 가해도 예측한 확률은 변하지 않을 것이라는 가설 하에 unlabeled 데이터에 noise(augmentation)을 주입한 뒤 noise가 없는 데이터와 noise가 주입된 데이터를 동일한 class 분포로 예측하도록 학습하는 기법입니다.
notion image
 

Entropy Minimization

softmax input에 temperature(T)란 hyperparameter를 적용해 보다 더 확실한 예측 확률을 만들어냅니다. 코드를 보시면 temperature를 적용한 값들의 더 극단적인 값을 갖는 것을 보실 수 있습니다.
notion image
import torchimport torch.nn as nn inputs = torch.randn(2, 3) m = nn.Softmax(dim=1)T = 0.3 print('#inputs\\n', inputs) print('\\n#softmax(inputs)\\n', m(inputs)) print('\\n#softmax(inputs)+temperature\\n', m(inputs/T)) # #inputs # tensor([[ 1.4486, -0.0243, -0.1175], # [ 0.3694, 0.8460, -0.0310]]) # #softmax(inputs) # tensor([[0.6954, 0.1594, 0.1452], # [0.3048, 0.4909, 0.2042]]) # #softmax(inputs)+temperature # tensor([[0.9874, 0.0073, 0.0053], # [0.1623, 0.7950, 0.0427]])
 

loss function

  • x: label이 있는 이미지들 집합(b개수만큼)
  • p(y∥x): x를 model을 이용해 class를 예측한 확률
  • A(.): strongly augmentation
  • α(.): weakly augmentation
  • p: one-hot label들의 확률 분포
  • H(p, q): p와 q의 확률 분포의 cross-entropy
notion image
supervised loss ls는 일반적으로 사용되는 supervised loss이며 weakly-augmented 데이터에 사용됩니다. 위의 내용을 바탕으로 설명드리면 label이 있는 이미지들에 약한 augmentation을 가한 이미지들이 y를 예측한 확률 값과 label의 확률 값의 cross-entropy loss입니다.
notion image
unsupervised loss lu는 strongly-augmented 데이터에 사용되는 loss입니다. 설명 드리자면 unlabeled 이미지에 강한 augmentaion을 가한 y를 예측한 확률 값과 약한 augmentation을 가한 y를 예측한 확률 값의 cross-entropy loss에 max(qb)를 곱해주는데 max(qb)가 threshold(τ) 이상인 경우만 곱해줍니다.
notion image
fixmatch의 최종 loss는 아래와 같습니다. supervised loss와 unsupervised loss에 가중치 람다를 곱한 합입니다.
notion image
 

Experiment & Results

  • weight decay regularization
  • standard SGD with momentum
  • cosine learning rate decay, k는 현재 training step, K는 전체 training step
  • hyperparameters: λu=1,η=0.03,β=0.9,τ=0.95,μ=7,B=64,K=220
  • backbone: wide resnet-28-2 with 1.5M parameters
notion image
 

Conclusion

적은 레이블링 데이터에도 높은 정확도를 높일 수 있는 비교적 간단한 알고리즘인 FixMatch 모델을 제안했고, weight decay와 optimizer 같은 설계가 중요하다는 것을 강조하였습니다. FixMatch와 같은 간단하면서도 높은 성능을 내는 SSL 알고리즘에 주목해야한다고 제안했습니다.
 

Reference