FixMatch : Simplifying Semi-Supervised Learning with Consistency and Confidence 논문요약
이미지 출처: http://dsba.korea.ac.kr/seminar/?mod=document&uid=248
https://amitness.com/2021/07/semi-supervised-learning/
https://amitness.com/2020/03/fixmatch-semi-supervised/
Background
Pseudo-label(a.k.a Self-training)
- labeled data로 모델 선학습
- 학습된 모델로 unlabeled image 예측
- 예측 결과 중 max confidence를 가지는 클래스를 pseudo label로 지정한다
- Pseudo-label과 예측결과를 가지고 cross-entropy loss를 계산한다
Consistency Regularization
- Main Idea : unlabeled data에 대하여, 변형된 이미지와 그것의 원본 이미지는 어떤 모델에 의해 같은 예측결과가 나와야 한다.
- 두 변형된 이미지의 확률값의 차이를 목적함수에 반영하여 temporary label의 신뢰성을 향상시키고자 함.
FixMatch
1 ) For Labeled data, weakly-augmentation(ex.flip-and-shift) & compute standard cross-entropy loss
2) For Unlabeled data,
2-1) compute the model's predicted class distribution given a weakly-augmented version
2-2) Get pseudo-label of weakly-augmented version
2-3) compute the model's predicted class distribution given a strongly-augmented version
2-4) Apply Cross Entropy with the predicted value and the pseudo label
- lambda_u = 1로 세팅한다. → 이전 논문들은 학습시킬수록 람다값을 증가시켰지만, FixMatch에서는 thresholding이 그 역할을 해줌. → 학습 초기에는 unlabeled data에 대한 예측결과가 threshold 이하이므로 자동으로 labeled data에 대해서만 학습됨. → 학습이 진행될수록 threshold를 넘어가 loss에 l_u가 포함된다.
Strong-Augmentation method
- AutoAugment : best accuracy를 가지는 최적의 augmentation을 찾는 강화학습 기반의 파이썬 라이브러리
→ labeled data로 미리 학습시키지 않아도 되는 augmentation 방법만을 채택
- RandAugment:
- 랜덤하게 N개의 augmentation선택 후, random magnitude M을 선택. M의 비율만큼 각 augmentation의 magnitude를 적용
- Stochastic하게(매 training step마다) 랜덤하게 적용해줌.
- CTAugment:
- magnitude val를 bin으로 나누어주고, 각 bin에게 가중치 부여.
- two transformations를 랜덤하게 선택 후, 각 transformation에 대해 가중치가 부여된 magnitude bin을 랜덤하게 선택.
- 가중치를 갱신하기 위해, labeled data를 두 transformation에 통과시킴.
- 예측결과와 실제 label을 비교하여, bin weight를 갱신시킴.
- 더 좋은 예측을 하는 모델이 고른 augmentation 기법을 선택하게 됨.
- RandAugment:
- Cutout
Experiments
- Labeled dataset을 각 클래스에서 한 장씩만 뽑아서 구성했을 때, representative 이미지로 이루어진 데이터셋이 가장 정확도가 높았음.
- CIFAR-10과 SVHN 데이터셋에 대해서 좋은 결과를 보임
- CIFAR-100에서는 ReMixMatch가 더 성능이 좋음 → ReMixMatch의 DA(Distribution Alignment)를 FixMatch에 추가해보았더니 에러율이 40.14% 나옴
- μ(= 한 배치 내에서 unlabeled data 비율)이 클수록 에러율이 감소함.
- η(=learning rate)을 배치사이즈에 따라 조절했더니 더 효과적이었음.
- threshold는 비교적 큰 값을 사용할 때 더 좋았음.
- Sharpening에서는 threshold가 특별한 경향성을 띄지 않음.
- sharpening을 하려면 hyperparemeter tuning이 더 필요할 것으로 보임.
- low-label regime에서는 weight decay를 잘 고르는 것도 중요함.
Comparison
공통점 :
- data augmentation 활용
- Unlabeled data에 대해서 guessed label을 형성
MixMatch
- labeled & unlabeled data에 대해 weak augmentation
- Sharpening 기법(Entropy Minimization)
- MixUp
ReMixMatch
- labeled & unlabeled data에 대해 strong augmentation
- weakly augmented unlabeled data의 predictions을 guessed label로 선정.
- Distribution Alignment
- Augmentation Anchoring
- CTAugment
FixMatch
- Unlabeled data에 대해 weak & strong augmentation
- weak augmented data의 predictions → Pseudo labelling(= One hot vector)
- Consistency Regularization
- CTAugment, RandAugment, Cutout 1