Improved Training of WGAN(WGAN-GP) 논문 요약
✓ 이 글은 다음 포스팅을 참조하였음
https://leechamin.tistory.com/232
https://jonathan-hui.medium.com/gan-wasserstein-gan-wgan-gp-6a1a2aa1b490
https://haawron.tistory.com/21
지난번 WGAN 포스팅 말미에서 WGAN-GP에 대해서 살짝 언급을 한 적이 있다.
해당 논문에서 그 주제를 다루고 있는데, 자세하게 읽진 않을 것이고 간단한 구조 설명 및 실험결과 정도만 보고 넘어가려고 한다.
우선은 기존 WGAN를 살짝 집고 넘어가자면, 기존 JS나 KL divergence가 아닌 EM distance방식을 사용하였고,
그 distance를 계산하기 위해서 w라는 새로운 변수를 잡았었다. 이 w라는 변수는 W라는 공간에 존재하는데, 이 공간은 compact해야된다는 성질이 있고,
compact하다는 것은 모든 함수(즉, 여기서는 discriminator)가 1-립시츠라는 조건을 만족해야했다.
그걸 지키기 위해서 다소 무식한 방법인 weight clipping을 사용했다.
여기서 1-립시츠 조건이란, 쉽게 말하면 모든 곳에서 기울기의 norm값이 1이하를 가져야 되는 것을 말한다.
아무튼 자세한 내용은 저번 포스팅을 참고.
그런데 저 방법이 다소 무식하긴 하지만, 어쨋든 좋은 효과를 내보였기 때문에..저자는 그냥 그 방법을 채택했었다.
그러나 왜 이 방법을 결국 버리고 새로운 방식을 선택을 한걸까? 저자는 그 이유를 optimal critic의 성질로부터 찾고 있다.
그것을 수학적으로 증명해낸 것이 바로 Proposition 1 과 Corollary 1이다.
굉장히 복잡하지만, 정리하자면 결국 이거다.
max 어쩌구 식의 최적해인 1-립시츠 함수를 f*라고 할때,
내분점 x_t는 다음과 같고 그때의 f*의 기울기의 norm값은 1이다.
즉 최적의 f의 성질이 다음과 같으므로, 기존 f_w에다가 그 성질을 추가하고, f*에 근사하도록 학습시키는 것이다.
이것이 WGAN-GP의 motivation이라고 할 수 있다.
그리고 무엇보다 weight clipping은 optimize가 힘들다는 문제가 있다.
이를 실험을 통해 증명해냈는데,
이 그림을 통해 설명하고 있다. 다소 난해해 보이지만 이해하면 쉽다.
우선 주황색 굵은 선은 어떤 data distribution이라고 생각하면 된다. 그리고 실선은 WGAN critic이 그 dataset에 대해서 optimal이라고 생각하고 학습한 value surface이다.
자세히보면 weight clipping은 다소 고차원적인? 좀 디테일적인 부분을 잡지 못하고 있는 것을 알 수 있다. 반면에 새로운 방법은 보다 디테일하게 optimization을 해낸 것을 알 수 있다.
다음 실험은 기울기변화와 weight의 분포를 관찰한 실험이다.
좌측그림에서는 weight clipping에서 gradient vanishing, exploding문제가 발생하고 있음을 알 수 있다.
그리고 우측그림에서는, weight clipping은 -0.01과 0.01에 weight가 거의 몰려있는 상황이다.
결과적으로 새로 정립된 Loss term은 다음과 같다.
여기서 x_hat은 실제 데이터분포 Pr에서 샘플링된 점들과, generator 분포 Pg에서 샘플링된 점들 사이에 생성된 직선위에 있는 내분점을 말한다.
새로 붙혀진 gradient penalty의 의미는 “gradient norm이 1로부터 떨어져 있을수록 penalty가 증가”한다는 의미로 생각하면 된다.
(참고로 여기서 𝜆=10 사용)
그리고 또다른 핵심은, 여기서는 batch normalization을 사용하지 않기로 하였다.
왜냐하면 batch norm은 Discriminator의 single input/output의 매핑 문제를 배치 전체에 대한 input/output 매핑문제로 바꿔버린다.
그러나 gradient penalty에서는 이 세팅이 필요가 없는게, 여기서는 한 input마다 독립적으로 gradient penalty를 먹이고 있다. 따라서 input간의 상관관계를 없애줘야 하기 때문에 batch norm대신에 layer normalization을 사용하게 되었다고 한다.
아 그리고 원래 기존 WGAN에서는 RMSProp을 사용하였다.
Adam을 사용할수 없었던 이유는, momentum으로 인한 불안정성이 너무 심했기 때문인데,
여기서는 Adam을 쓰니까 오히려 성능이 더 증가하였다. (정확한 문제해결방법이 뭐였는지 모르겠지만..)
실험결과를 보면, 결국엔 다른 모델들은 mode collapse에 빠졌거나 불안정한 결과가 나온 반면에, WGAN-GP는 좋은 결과가 나온 것을 볼 수 있다.
이외에 여러 다른 실험들은 논문을 참고하길 바란다.