이번 포스트에서는 Prototypical networks(protonet)를 pytorch 코드로 구현해 보려고 합니다. 논문을 읽다 보면 experiment와 networks architecture에 대해 친절한 논문도 있고 그렇지 않은 논문도 있어서 혹시 잘못된 것이 발견된다면 댓글로 알려주세요! 질문 또한 환영입니다.
전체 코드는 github를 봐주세요!
Dataset 만들기
Protonet을 실험하기 위해 사용한 dataset은 omniglot dataset과 miniImageNet dataset입니다. Omniglot dataset에 대해서는 이전 글에서 설명했으니 넘어갑니다. MiniImageNet dataset은 ImageNet dataset의 축소 버전으로 100개의 class당 600개의 이미지로 총 60,000개의 data가 존재합니다. 앞으로 설명할 내용들은 miniImageNet dataset을 기준으로 experiment 할 때에 대해 설명합니다.
Protonet은 train 할 때 prototype을 위한 image와 test를 위한 image가 각각 필요합니다. 이 prototype을 위한 image를 support set이라 부르고, test를 위한 image를 query set 혹은 query point라고 부릅니다. 논문에서는 train할 때의 way를 test 할 때 way 보다 크게 설정하여 사용하는 것이 더 좋은 결과를 낼 수 있다고 합니다. 따라서 1-shot classification에서는 30-way의 episode를 사용하고, 5-shot classification에서는 20-way episode를 사용합니다. Query point는 모두 episode당 15개를, support set은 train과 test를 맞추는 것이 정확도가 높게 나왔다고 합니다.
즉, 5-shot classification을 할 경우 train을 위해 20-way 5-shot 15 query point로 data loader를 구성하고 test는 5-way 5-shot 15-query point로 data loader를 구성합니다. Train dataset의 dataloader만 생각했을 때 하나의 episode당 총 400장의 이미지를 선택해야 합니다. 따라서 한 번의 episode마다 20개의 class를 선택하고 각 class에서 랜덤 하게 20장의 이미지를 선택하여 support set과 query set으로 나뉘어야 합니다. Train, Validation, Test split은 Vinyals et al. 의 방식을 사용하였습니다.
Dataloader를 만들기 위해 batch sampler를 사용하였고, 사용법은 yield batch의 batch에 선택할 index array를 넣어준다고 간단하게 생각하면 됩니다.
class PrototypicalBatchSampler(Sampler):
def __init__(self, ...):
self.idxs = range(len(self.labels))
self.indexes = np.empty((len(self.classes), max(self.counts)), dtype=int) * np.nan
self.indexes = torch.Tensor(self.indexes)
self.num_per_class = torch.zeros_like(self.classes)
for idx, label in enumerate(self.labels):
label_idx = np.argwhere(self.classes == label).item()
self.indexes[label_idx, np.where(np.isnan(self.indexes[label_idx]))[0][0]] = idx
self.num_per_class[label_idx] += 1
def __iter__(self):
nss = self.num_samples_support
nsq = self.num_samples_query
cpi = self.classes_per_it
for _ in range(self.iterations):
batch_s = torch.LongTensor(nss * cpi)
batch_q = torch.LongTensor(nsq * cpi)
c_idxs = torch.randperm(len(self.classes))[:cpi] # 랜덤으로 클래스 way개 선택
for i, c in enumerate(self.classes[c_idxs]):
s_s = slice(i * nss, (i + 1) * nss) # 하나의 클래스당 선택한 support 이미지
s_q = slice(i * nsq, (i + 1) * nsq) # 하나의 클래스당 선택한 query 이미지
label_idx = torch.arange(len(self.classes)).long()[self.classes == c].item()
sample_idxs = torch.randperm(self.num_per_class[label_idx])[:nss + nsq]
batch_s[s_s] = self.indexes[label_idx][sample_idxs][:nss]
batch_q[s_q] = self.indexes[label_idx][sample_idxs][nss:]
batch = torch.cat((batch_s, batch_q))
yield batch
def __len__(self):
return self.iterations
# 이 코드는 간소화된 코드. 정확한 코드는 Github 참조.
먼저 전체 데이터에 대해 class별로 index를 구하고 각 class당 몇 개의 이미지가 있는지 저장하는 list가 필요합니다. 따라서 __init__에서 indexes와 num_per_class를 설정해줍니다. indexes는 [class_num][element]와 같이 2차원으로 되어있고, num_per_class는 [class_num]으로 1차원의 list입니다. miniImagenet dataset의 경우 각 class당 element의 개수가 600개로 일정하지만, 다른 custom data를 위해서 num_per_class를 구하였습니다. 그런 다음 __iter__에서 batch마다 구성할 data의 index를 설정합니다. 랜덤으로 n-way개의 class를 선택하고(c_idxs) support set의 개수와 query set의 개수만큼 index를 랜덤 하게 선택하여 batch_s, batch_q에 넣어주고 두 개를 concatenate 하여 yield로 batch를 넘겨줍니다.
Model 만들기
protonet의 모델은 siamese net과 비슷한데 단순히 4개의 convolution layer를 이어 붙였습니다. 아래의 코드에서는 ConvBlock으로 선언되어 있고 ConvBlock은 torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU, torch.nn.MaxPool2d를 차례대로 거치게 구성되어 있습니다.
Loss function 만들기
기존의 ML에서는 loss를 구하기 위해 pytorch의 loss function에 model의 output값과 label 값을 넣어주면 되지만 protonet에서는 output값을 바로 loss function에 넣을 수 없습니다. Protonet은 단순히 이미지들을 feature vector로 만들어주는 역할만 하기 때문입니다. Data loader에서 x는 support set과 query set을 붙여놓았기 때문에 먼저 support set과 query set을 나눠야 합니다. 아래 함수에서 input은 model에서 넣어서 나온 output, target은 label입니다.
support set과 query set으로 나누기 위해 n_support도 함께 parameter로 받습니다. support idxs는 class별로 index를 정리한 list입니다. 이를 가지고 각 class의 feature를 평균을 내어 prototype을 만듭니다. query samples는 마찬가지로 calss별로 index를 만들고 평균을 내지 않은 채로 남깁니다. 이 둘의 euclidean distance를 계산하고 log softmax를 취해 y_hat을 구할 수 있습니다. 위의 과정을 거친 이유는 각 batch마다 class의 label은 랜덤 하지만, softmax의 argmax값은 항상 0~19로 일정하기 때문입니다. 따라서 target_label은 0부터 19까지 quert의 개수만큼 만들어주면 됩니다.
이 값들을 통해 loss와 accuracy를 계산할 수 있습니다.
Result
추후 업데이트 예정
'Machine Learning > Meta Learning' 카테고리의 다른 글
[논문 리뷰] Prototypical Networks for Few-shot Learning (1) | 2021.02.07 |
---|---|
[논문 코딩] Siamese Neural Networks for One-shot Image Recognition - Pytorch (1) | 2021.02.07 |
[논문 리뷰] Siamese Neural Networks for One-shot Image Recognition (6) | 2021.01.13 |
Meta Learning (1) | 2021.01.13 |
댓글