본문 바로가기
Machine Learning/Meta Learning

[논문 코딩] Siamese Neural Networks for One-shot Image Recognition - Pytorch

by SoongE 2021. 2. 7.
 

[논문 리뷰] Siamese Neural Networks for One-shot Image Recognition

Meta Learning 학습하는 과정을 학습하다. Meta learning은 현재 AI에서 가장 유망하고 트렌디한 연구분야로 AGI(Artificial General Intelligence)로 나아갈 수 있는 매우 중요한 디딤돌이라고 볼 수 있다. AGI란..

rhcsky.tistory.com

이전에 리뷰하였던 Siamese Neural Networks for One-shot Image Recognition에 나왔던 siamese networks를 직접 코드로 구현해 보려고 합니다. 이것저것 논문을 읽다 보면 experiment와 networks architecture에 대해 친절한 논문도 있고 그렇지 않은 논문도 있어서 혹시 잘못된 것이 발견된다면 댓글로 알려주세요! 질문 또한 환영입니다😊.

전체 코드는 github를 봐주세요!

 

Dataset 만들기

Omniglot dataset

논문에서는 여러개의 데이터셋이 있지만 이번 글에서는 Omniglot dataset을 이용합니다. Omniglot dataset은 50개의 alphabet이 존재하고 각각의 alphabet은 15~40개의 character를 가지고 있어서 총 1623개의 class가 존재하는 데이터셋 입니다(ㅚ, ㅃ 같은 한국어도 보여서 너무 좋네요). dataset을 어떻게 구성하였는지는 이전 글을 참고하고, 바로 코드로 작성해봅니다. 일단 Omniglot dataset을 다운로드하면 images_backgroundimages_evaluation을 받을 수 있는데 각각 40개, 10개의 class를 가지고 있습니다. images_background에서 다시 30개의 train set과 10개의 validation set으로 나누어서 저장한 다음 사용해야 하므로 아래 코드와 같이 train data와 validation data를 랜덤으로 선택해 분리해줍니다. 

Split background data into train, validation data

back_alpha에는 40개의 class가 들어있고 그곳에서 랜덤으로 30개를 골라 train_alpha에 넣어주고, val_alpha는 back_alpha중 train_alpha에 없는 class를 골라 넣어줍니다. test_alpha는 제공되는 데이터 그대로 사용합니다.

 

이렇게 정리한 데이터로 dataset class를 새로 선언해줍니다. Torchvision의 dataset class를 상속받는 Omniglot train dataset class를 선언하고 그중 get item 부분만 일부 가져와 봤습니다. 

OmniglotTrainDataset class

50% 비율로 same, different를 만들어야 해서 index가 짝수, 홀수 일 때로 나누어 코드를 작성했는데, index가 홀수일 때는 서로 같은 class에서 이미지는 가져오고 index가 짝수일 때는 서로 다른 class에서 이미지를 가져올 수 있도록 하였습니다. 사용되는 self.dataset은 이전에 만들어 놓았던 train 폴더입니다.

 

validation, test dataset은 train dataset과 조금 다르게 구성되어 있는데 index를 짝수, 홀수로 나누는 것이 아니라 way에 따라서 나눕니다. 만약 20 way라면 0번째 이미지만 기준이 되는 이미지와 같은 class로 만들고 나머지 1~19번째 이미지는 기준 이미지와 다른 class의 이미지로 전달해야 하기 때문입니다.

OmniglotTestDataset class

 

Data loader는 pytorch의 DataLoader를 그대로 사용하고 batch_size=way(n-way k-shot의 way)로 지정하는 것만 조심하면 됩니다.

 

 

Model 만들기

Siamese network의 구조는 간단합니다. 4번의 conv layer(self.conv)를 거치고 flatten layer(self.linear)와 feature vector의 similarity를 만들어주는 layer(self.out) 만들어주면 되는데 sub_forward가 존재하는 게 차이점입니다. 보통은 forward 한 번으로 끝나지만 siamese networks는 두 개의 이미지를 weight을 공유하는 module에 넣고 여기서 나온 feature vector를 가지고 다시 학습을 하기 때문에 두 개의 이미지를 encoding 하는 sub_forward가 필요합니다.

Siamese networks model - init

sub_forward는 convolution block(self.conv)과 sigmoid를 거치게 됩니다.

Siamese networks model - sub forward

forward에서는 두 개의 이미지를 각각 sub forward를 거치게 하여 feature vector를 만들고 둘의 차이를 l1 distance로 계산합니다. 그리고 그 계산 값을 다시 linear layer를 거치게 하는데 원래는 마지막에 sigmoid를 통과시켜서 결과를 내야 하지만 pytorch에는 BCEWithLogitsLoss라는 loss fuction을 활용하기 위해 마지막은 sigmoid 연산을 해주지 않습니다. (BCEWithLogitsLoss는 자동으로 sigmoid를 하기 때문에 모델의 마지막에 sigmoid를 해주면 안돼요!)

Siamese networks model - forward

 

학습하기

이제 model과 data loader를 모두 만들었으니 학습을 진행합니다. Optimizer는 SGD, learning rate는 1e-1~1e-4까지 매 epoch마다 1%씩 감소하도록 scheduler를 만들었다고 합니다. Loss function은 BCEWithLogitsLoss를 이용합니다.

model = SiameseNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer=optimizer,
			lr_lambda=lambda epoch: 0.99 ** epoch)

criterion = torch.nn.BCEWithLogitsLoss()

for epoch in epochs:
    model.train()
    for i, (x1,x2,y) in train_loader:
        out = model(x1, x2)
        loss = criterion(out, y.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()

    model.eval()
    with torch.no_grad():
        for i, (x1,x2,y) in valid_loader:
            out = model(x1, x2)
            
            y_pred = torch.sigmoid(out)
            y_pred = torch.argmax(y_pred)

            if y_pred == 0:
                correct += 1
                

train부분은 일반적인 것과 같으니 그냥 넘어가고 evaluation부분에서 y_pred == 0일 때만 correct 점수를 올려줬는데 y_pred가 0일때만 correct가 올라가는 이유가 무엇일까요? 저희가 TestDataset Class를 만들때 index%way == 0일때만 same class의 이미지를 넣어줬기 때문입니다. N-way가 20이라면 20개의 이미지 중 0번째 이미지만 기준이 되는 이미지와 같은 class이고 나머지 이미지는 다른 class의 이미지이기 때문에 y_pred 가 0 일때 답을 맞혔다고 할 수 있습니다.

 

이렇게 omniglot dataset을 이용한 siamese networks 코딩을 마치겠습니다. 위 코드는 간략화 한 코드이기 때문에 자세한 코드와 실행 방법은 github를 참고해주세요😃

https://github.com/Rhcsky/MetaLearning-pytorch/tree/main/siamese

 

Rhcsky/MetaLearning-pytorch

Meta learning algorithm using pytorch. Contribute to Rhcsky/MetaLearning-pytorch development by creating an account on GitHub.

github.com

 

댓글