FixMatch : Simplifying Semi-Supervised Learning with Consistency and Confidence 논문요약

이미지 출처: http://dsba.korea.ac.kr/seminar/?mod=document&uid=248

https://amitness.com/2021/07/semi-supervised-learning/

https://amitness.com/2020/03/fixmatch-semi-supervised/


Background

Pseudo-label(a.k.a Self-training)

image

  • labeled data로 모델 선학습
  • 학습된 모델로 unlabeled image 예측
  • 예측 결과 중 max confidence를 가지는 클래스를 pseudo label로 지정한다
  • Pseudo-label과 예측결과를 가지고 cross-entropy loss를 계산한다

Consistency Regularization

image image

  • Main Idea : unlabeled data에 대하여, 변형된 이미지와 그것의 원본 이미지는 어떤 모델에 의해 같은 예측결과가 나와야 한다.
  • 두 변형된 이미지의 확률값의 차이를 목적함수에 반영하여 temporary label의 신뢰성을 향상시키고자 함.

FixMatch

image

1 ) For Labeled data, weakly-augmentation(ex.flip-and-shift) & compute standard cross-entropy loss

2) For Unlabeled data,

2-1) compute the model's predicted class distribution given a weakly-augmented version

2-2) Get pseudo-label of weakly-augmented version

2-3) compute the model's predicted class distribution given a strongly-augmented version

2-4) Apply Cross Entropy with the predicted value and the pseudo label

image

  • lambda_u = 1로 세팅한다. → 이전 논문들은 학습시킬수록 람다값을 증가시켰지만, FixMatch에서는 thresholding이 그 역할을 해줌. → 학습 초기에는 unlabeled data에 대한 예측결과가 threshold 이하이므로 자동으로 labeled data에 대해서만 학습됨. → 학습이 진행될수록 threshold를 넘어가 loss에 l_u가 포함된다.

Strong-Augmentation method

  • AutoAugment : best accuracy를 가지는 최적의 augmentation을 찾는 강화학습 기반의 파이썬 라이브러리 → labeled data로 미리 학습시키지 않아도 되는 augmentation 방법만을 채택
    • RandAugment:
      • 랜덤하게 N개의 augmentation선택 후, random magnitude M을 선택. M의 비율만큼 각 augmentation의 magnitude를 적용
      • Stochastic하게(매 training step마다) 랜덤하게 적용해줌.
    • CTAugment:
      • magnitude val를 bin으로 나누어주고, 각 bin에게 가중치 부여.
      • two transformations를 랜덤하게 선택 후, 각 transformation에 대해 가중치가 부여된 magnitude bin을 랜덤하게 선택.
      • 가중치를 갱신하기 위해, labeled data를 두 transformation에 통과시킴.
      • 예측결과와 실제 label을 비교하여, bin weight를 갱신시킴.
      • 더 좋은 예측을 하는 모델이 고른 augmentation 기법을 선택하게 됨.
  • Cutout

Experiments

image

  • Labeled dataset을 각 클래스에서 한 장씩만 뽑아서 구성했을 때, representative 이미지로 이루어진 데이터셋이 가장 정확도가 높았음.

image

  • CIFAR-10과 SVHN 데이터셋에 대해서 좋은 결과를 보임
  • CIFAR-100에서는 ReMixMatch가 더 성능이 좋음 → ReMixMatch의 DA(Distribution Alignment)를 FixMatch에 추가해보았더니 에러율이 40.14% 나옴

image

  • μ(= 한 배치 내에서 unlabeled data 비율)이 클수록 에러율이 감소함.
    • η(=learning rate)을 배치사이즈에 따라 조절했더니 더 효과적이었음.
  • threshold는 비교적 큰 값을 사용할 때 더 좋았음.
  • Sharpening에서는 threshold가 특별한 경향성을 띄지 않음.
    • sharpening을 하려면 hyperparemeter tuning이 더 필요할 것으로 보임.
  • low-label regime에서는 weight decay를 잘 고르는 것도 중요함.

Comparison

image

공통점 :

  • data augmentation 활용
  • Unlabeled data에 대해서 guessed label을 형성

MixMatch

  • labeled & unlabeled data에 대해 weak augmentation
  • Sharpening 기법(Entropy Minimization)
  • MixUp

ReMixMatch

  • labeled & unlabeled data에 대해 strong augmentation
  • weakly augmented unlabeled data의 predictions을 guessed label로 선정.
  • Distribution Alignment
  • Augmentation Anchoring
  • CTAugment

FixMatch

  • Unlabeled data에 대해 weak & strong augmentation
  • weak augmented data의 predictions → Pseudo labelling(= One hot vector)
  • Consistency Regularization
  • CTAugment, RandAugment, Cutout ​1
박나깨

박나깨

저는 Deep Learning, Computer Vision, AI, Image Processing에 관심이 있는 학생입니다.