본문 바로가기

AI/Paper - Theory

[MoCo 논문 리뷰] - Momentum Contrast for Unsupervised Visual Representation Learning

반응형

*MoCo 논문 리뷰를 코드와 같이 분석한 글입니다! SSL 입문하시는 분들께 도움이 되길 원하며 궁금한 점은 댓글로 남겨주세요.

*SSL(Self-Supervised-Learning) 중 contrastive learning을 위주로 다룹니다!

 

Moco paper: [1911.05722] Momentum Contrast for Unsupervised Visual Representation Learning (arxiv.org)

 

Momentum Contrast for Unsupervised Visual Representation Learning

We present Momentum Contrast (MoCo) for unsupervised visual representation learning. From a perspective on contrastive learning as dictionary look-up, we build a dynamic dictionary with a queue and a moving-averaged encoder. This enables building a large a

arxiv.org

 

Moco github: GitHub - facebookresearch/moco: PyTorch implementation of MoCo: https://arxiv.org/abs/1911.05722

 

Momentum Contrast for Unsupervised Visual Representation Learning

We present Momentum Contrast (MoCo) for unsupervised visual representation learning. From a perspective on contrastive learning as dictionary look-up, we build a dynamic dictionary with a queue and a moving-averaged encoder. This enables building a large a

arxiv.org


Contents

1. Simple Introduction

2. Background Knowledge: SimCLR

3. Method

    - Shuffling Batch Normalization

    - Momentum method

    - How to update dictionary(queue)


Simple Introduction

MoCo

MoCo는 기존의 SimCLR와 거의 구조가 동일합니다.

두 개의 Batch, 그리고 각각 augementaion과 encoder에 넣어서 서로의 유사도 계산을 한다는 점이다.

 

그러나, MoCo의 한쪽 encoder를 보면 momentum encoder가 있다.

그리고 또 새로운 개념인 queue가 등장했다.

여기서는 queue가 Negative sample 역할을 수행하고, momentum encoder을 통해서 나오는 key는 positive sample의 역할을 수행한다.

 

이로써, negative sample에 대한 의존성을 줄였고, queue라는 것을 이용해서 효율적으로 memory를 이용할 수 있도록 하였다.

 

한번 좀 더 자세히 들어가보자!


(Background Knowledge: SimCLR)

 

SimCLR 논문 리뷰: https://kyujinpy.tistory.com/39

 

[SimCLR 논문 리뷰] - A Simple Framework for Contrastive Learning of Visual Representations

*SimCLR 논문 리뷰를 위한 글입니다! SSL 입문하시는 분들께 도움이 되길 원하며 궁금한 점은 댓글로 남겨주세요. *SSL(Self-Supervised-Learning) 중 contrastive learning을 위주로 다룹니다! *해당 글에서는 Proxy

kyujinpy.tistory.com

SimCLR의 한계점을 극복한 모델이라고 볼 수 있기 때문에, SimCLR을 공부하고 오시면 좋을 것이라고 생각이 듭니다.

 

즉, SimCLR의 본질적인 문제, 'Negative sample에 대한 의존성과 큰 batch size가 있어야만 좋은 성능이 나온다'는 것을 해결하였습니다.

 

어떻게 해결했는지 자세히 알아봅시다.


Method

 

MoCo에서는 총 3가지 구조를 보여주면서 서로 비교를 했다.

 

우리가 가장 주의 깊게 봐야할 것은 c)번, MoCo이다.

MoCo의 구조는 다음과 같다.

 

MoCo

1. Batch를 encoder에 넣어서 q를 만든다.

2. 또 다른 Batch를 momentum encoder에 넣어서 k를 만든다.

2-1. 각각의 encoder에서 나온 값을 batch normalizaiton을 한다.(이때 shuffling BN을 활용한다. 밑에서 자세히 언급.)

3. queue라는 하나의 memory bank를 생성한다.

    - 초깃값은 random하게 설정한다.

    - queue는 dictionary라고 하며, negative sample에 대한 사전 정의가 되어있는 embedding queue이다.

3. q와 k의 유사도, 그리고 q와 queue사이의 유사도를 계산하여서 각각 positive sample과 negative sample에 대한 결과값을 계산한다.

4. 그리고 InfoNCE(loss function)를 이용해서 값을 최종적으로 만들어낸다.

    - Softmax + Cross_Entropy 라고 생각하면 편하다.

    - 실제 코드에서도 Cross_Entropy를 이용한다.

 

Queue가 어떻게 변화하고, 왜 Momentum이라는 말이 붙었는지 이해가 아직 잘 안될텐데, 밑에서 자세히 살펴보자!


Result

+) 논문의 저자들은 사전에 정의되는 queue의 개수(negative sample의 개수)가 많으면 많을수록 성능이 좋다는 것을 증명했다.

+) 논문에서는 65536개의 dictionary를 정의한다.

+) a)는 end-to-end 방식으로, SimCLR과 같은 방식이라고 생각하면 된다.

+) b)는 memory bank를 이용하는 방식인데, 이 memory bank는 encoder에서 나온 q값을 이용해서 만들어지는데, 이럴 경우 memory bank에 있는 값들이 consistent하지 않아서 모델 학습에 어려움을 겪는다. (consistent한 문제를 momentum을 이용해서 해결하는데, 밑에서 자세히 언급하겠다.)


1. Shffling Batch Normalization

논문의 저자들은 2개의 batch가 각각 encoder와 momentum encoder로 들어가고 나서 나온 output에 대해서 Shuffling Batch Normalization을 실행한다.

 

이때 왜 Shuffling이라는 말이 붙었을까?

이유는 간단하다.

BN problem

만약 위의 사진 처럼 batch가 구성되어 있다면, 각 batch에 대해서 너무나 빠르게 문제를 풀어버리고 낮은 loss solution을 바로 찾아낼 수 있다.

즉, shorcut 학습이 될 수 있다는 것이다.

 

MoCo Batch Method
Shuffle BN

따라서 위에 처럼 각 GPU마다 이용되는 batch를 합친 후 섞어서 다시 divide한 batch를 Momentum encoder에 넣어줘서 학습을 진행하는 방식으로 위와 같은 문제를 해결하였다.

(SSL은 dataset이 크기 때문에, multi-GPU를 대부분 이용한다.)


2. Momentum Method

Momentum encoder
Update momentum encoder

Momentum encoder는 일단 기본적으로 gradient로 학습을 안하는 구조이다.

그래서 MoCo는 단방향 backpropagation 구조를 가지고 있는데, 여기서 Momentum encoder의 parameters를 momentum의 식으로 update를 진행한다.

 

논문에서는 m=0.999로 두고 설정하는데, 이럴 경우 기존의 k_parameters가 많이 반영되고, encoder에서 나온 q_parameters는 많이 반영이 안되므로, k값에 대해서 consistent를 유지하면서 갈 수 있다.


3. How to update dictionary(queue)

queue intialize

queue는 처음에 위와 같은 코드로, random한 값으로 시작된다.

보통 shape은 dim=128, K=65536으로 시작된다.

queue_ptr은 pointer의 개념으로, 쉽게 얘기하면 queue가 업데이트될 부분을 정해주는 역할이다.

dequeue and enqueue

Momentum encoder에서 나온 keys값으로 queue에 업데이트 해준다.

업데이트 하는 방식은 queue의 FIFO 방식을 따른다.


- 2022.12.31 Kyujinpy 작성.

반응형