본문 바로가기
Machine Learning/Meta Learning

[논문 리뷰] Prototypical Networks for Few-shot Learning

by SoongE 2021. 2. 7.

리뷰할 논문은 NIPS 2017에 소개된 Prototypical Networks for Few-shot Learning으로 Prototypical Networks를 이용하여 few-shot learning을 할 수 있는 모델에 대해 설명합니다.

 

Abstract

Train dataset에 있지 않은 새로운 class에 대해 학습할 때, 새로운 calss에 대한 dataset이 부족할 경우를 이를 대처하기 위한 방안으로 prototypical networks를 제안한다.

Prototypical networks는 각 class의 prototype representation까지의 거리를 계산해서 classification을 수행할 수 있는 metric space(거리 공간)을 학습한다. Few-shot learning에 대한 최근 연구들과 비교해보면 protonet(Prototypical networks를 줄여 말합니다)은 제한된 데이터 체계에서 유익하고 단순한 유도 편향을 반영하며 우수한 결과를 달성하였다. 

 

Introduction

prototypical networks의 접근법과 방식에 대해 설명한다.

Few-shot learning은 train dataset에는 없는 새로운 class에 대해 적은 데이테만을 가지고 있을 경우, 새로운 class를 수용하기 위해 분류기가 조정되어야 하는 작업이다. 가장 기본적인 접근법은 새로운 데이터를 가지고 re-training하는 것 이지만, overfit될 확률이 매우 높다. 문제가 매우 어려운 반면, 인간은 one-shot classification을 할 수 있는 능력이 있다. 그래서 최근 두 개의 연구는 'Vinyals et al.의 matching networks'와 'Ravi and Larochelle의 meta-LSTM'이 있다. 저자는 이 둘의 문제점인 overfitting을 지적하고 이를 줄이기 방향으로 protypical networks를 고안하였다.

Prototypical networks in the few-shot and zero-shot scenarios. C* are computed as the mean of embedded support examples for each class.

Protonet은 각 class에 대해 single prototype representation이 있는 embedding을 base로 접근하였다. 이를 위해 neural network를 사용하여 임베딩 공간에 대한 입력의 비선형 매핑을 학습하고, 임베딩 공간에서 설정된 support set의 평균으로 class의 prototype을 만든다. Zero-shot learning에서도 같은 방식의 접근법을 가진다. 따라서 protonet은 각 class의 prototype 역할을 하기 위해 meta-data를 공유 공간(shared space)에 임베딩하는 것을 학습한다.

Classification을 수행할 때 임베딩 된 query point에서 가장 가까운 class prototype을 찾는다. 각 class의 평균으로 prototype을 만들고 Euclidean distance를 이용해서 query point와의 거리를 계산한다. 이 거리 중 가장 가까운 prototype을 결정하고 query point의 class를 해당 prototype의 class로 예측한다.

 

What is prototype?

이 부분은 논문에는 있지 않은 내용이며, prototype이 무엇인지, protonet이 어떻게 동작하는지 이해를 돕기 위해 추가로 작성되었습니다.

K-means clustering

위 그림은 wiki pidia에 k-mean clustering을 검색하면 나오는 그림입니다. 위에서 보았던 그림과 비슷하죠? 결론부터 말하면 동그란 색깔점이 논문에서 말하던 prototype 혹은 prototype representation입니다. 각각의 네모점은 색깔점과 거리를 계산하여 가장 가까운 색깔점의 class를 따르게 됩니다.(두 번째 그림) 그 다음 각 class에서 거리의 평균을 내어 prototype의 위치를 update합니다.(세 번째 그림) 이런 식을 계속 해서 prototype을 만들어갑니다. 그럼 prototype이 뭔지 대충 감이 오셨으니 논문의 그림으로 다시 넘어가봅시다.

Protonet process

 

gif파일로 만들어본 protonet의 작동 process입니다. 3-way 5-shot의 classification에서 C1의 초록색 부분만 계산한다고 가정하면 X1~X5는 support set의 image tensor data입니다. 이 image tensor data를 모델(90도 회전한 파란색 마름모)에 넣으면 Z1~Z5(초록색 점)이 만들어집니다. 그리고 Z1~Z5를 모두 평균한 값이 C1(검은색 점)이 되고 이것이 하나의 class의 prototype이 됩니다! 이해가 좀 가시나요?? 이를 반복하면  C1, C2, C3 처럼 각 class의 prototype의 모두 구하게 되는 것이죠. 이제 query set의 data 하나를 가져와서 어떤 class인지 예측한다고 해봅시다. query set의 image tensor data는 Xq가 되고 Xq를 모델에 넣으면 Zq가 나옵니다. 이 Zq를 가지고 C1, C2, C3와 각각 Euclidean distance로 거리를 계산해줍니다. 계산한 값에 -를 붙여주게 되면 비로소 similarity가 되는 것이죠.(거리가 멀수록 즉, 값이 클수록 similarity는 낮은거니까 -를 붙여주면 값이 클수록 similarity가 커지겠죠?)

위 gif만 이해하시면 protonet의 전반적인 process를 모두 아는 겁니다😉😉

 

Prototypical Networks

- Notation 

Few-shot classification에서는 N개의 labeled된 support set example $S =\{(x_1, y_1), . . . ,(x_N , y_N)\}$가 존재하고 $x_i \in \mathbb{R}^D$는 D-dimensional feature vector of example이고, $y_i \in \{1,...,K\}$는 corresponding label이다. $S_k$는 class $k$에 대한 dataset을 말한다. 

 

- Model

Protonet은 M-dimensional representation인 $c_k \in \mathbb{R}^M$ or $prototype$을 계산하고 각각의 class는 embedding function $f_\phi : \mathbb{R}^D \rightarrow \mathbb{R}^M$을 거친다. $\phi$는 learnable parameter(weight)이다. 각각의 prototype은 각 class에 속한 mean vector of the embedded support points이다.

Class k에 대한 prototype 계산

Distance function $d = \mathbb{R}^M \times \mathbb{R}^M \rightarrow (0, +\infty)$로 protonet은 embedding space에서의 prototype에 대한 distribution을 생성해 내는데 이 distribution은 distance로 softmax한 query point x의 class를 결정할 때 필요하다.

말이 어려운데, protonet은 embedding space에서 각 class를 대표하는 prototype의 분포를 생성하고(위 그림에서 C1,C2 등등) query dataset 중 하나인 query point x를 distance 기반의 softmax를 취한 값을 비교할 때 사용한다. 

Probability about label y and data x in class k

Negative log-probability $J(\phi) = - logp_\phi(y=k|x)$를 최소화 하기 위해서 SGD를 이용하고 Training episode는 training set에서 랜덤하게 class를 선택하여 만든다. 그리고 남은 것 중 일부를 선택하여 query point를 만든다. 

 

아래의 그림은 training episode에서 loss $J(\phi)$를 계산하기 위한 pseudocode인데 gif로 조금 더 알기 쉽게 만들었습니다.

Overview of algorithm

- Prototypical Networks as Mixture Density Estimation

Regular Bregman divergences로 알려져 있는 distance function에 대해서 protonet algorithm은 support set에 대해 performing mixture density estiation을 적용한다. Regular Bregman divergence $d$는 아래 식으로 정의된다.

Bregman divergences는 squared Euclidean distance $|| z-z' ||^2$ 또한 포함한다. 그리고 protonet은 거리를 계산하기 위해 squared Euclidean distance를 사용한다.(즉, 거리 계산으로 유클리디안 거리의 제곱을 사용한다는 소리를 어렵게 해놓은 것 같다.)

 

- Design Choices

실험 결과 train의 $N_c$(number of class)는 test의 $N_c$보다 크게 잡는 것이 더 좋은 결과를 내었고, train과 test의 $N_s$(number of shot per class)는 같게 하는 것이 대부분 좋은 결과를 내었다.

Few-shot classification accuracies on Omniglot.

 

Experiment: Omniglot Few-shot Classification

Omniglot dataset은 handwritten character로 50개의 alphabet이 존재하고 각 alphabet의 character들은 20개의 example이 존재한다. Vinyals et al.의 procedure를 따랐고 28 x 28 사이즈로 resize 하였으며, 90의 배수만큼 rotation을 통해 augmenting을 진행하였다. Embedding architecture은 4개의 convolution block을 사용하였고, 각 block은 64 filter의 3x3 convolution, batch normalization, ReLU, 2x2 max-polling으로 구성되어 있다. Initial learning rate는 $10^(-3)$이고 2000episode마다 절반으로 learning rate를 줄였다. 1-shot, 5-shot scenarios는 train을 할 때 60개의 class와 5개의 query point를 사용한다. 

 

Experiment: miniImageNet Few-shot Classification

miniImageNet dataset역시 Vinyals et al.의 split을 사용하였고 84x84 사이즈의 60,000개 color 이미지로 구성되어 있고 class당 600개씩 총 100개의 class가 존재한다. Embedding architecture와 learning rate는 Omniglot dataset을 훈련할 때와 같았고, train dataset은 1-shot일때는 30-way, 5-shot일때는 20-way로 구성하였다.

Few-shot classification accuracies on miniImageNet

Conclusion

신경망에 의해 학습된 representation space에서 각 class를 example을 사용하여 나타낼 수 있다는 생각에 기반하여 few-shot learning을 위한 prototype networks라는 간단한 방법을 제안했다. 이러한 networks가 episodeic training을 사용하여 few-shot에 잘 작동하도록 훈련한다. 이러한 접근 방식은 최근 meta learning의 접근법보다 훨씬 단순하고 효율적이며, 데이터에 따른 정교한 확장 없이도 SOTA를 달성하였다. 

 

Reference

https://arxiv.org/abs/1703.05175

https://www.youtube.com/watch?v=rHGPfl0pvLY

 

댓글