github

https://github.com/chaeyeongyoon/Classification_compare

각 모델 구현 코드 설명 - models.py

ResNet

VGG

Train code

일단 필요한 라이브러리들을 import 해줍니다

import argparse
import sys
import time
import os
import datetime
from unittest import result
from models.resnet import Resnet
from models.vgg import VGG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

accuracy 계산에 쓰이는 accuracy 함수

def accuracy(pred_tensor, y_tensor):
    # accuracy = torch.eq(pred_tensor, y_tensor).sum().item() / len(pred_tensor)
    accuracy = torch.sum(pred_tensor == y_tensor) / len(pred_tensor)
    return accuracy

처음에 torch.eq로 짰다가 torch.sum으로 하는게 더 빠르다고 생각해 (알고리즘 상으로) 수정했습니다.

pred tensor==y_tensor는 같은 위치의 요소를 비교해 요소의 값이 같은 위치는 1, 값이 다른 위치는 0인 리스트를 반환합니다.

거기에 torch.sum을 해주면 같은 개수가 나오게 되고 그 값을 텐서의 길이, 즉 요소의 개수로 나눠주면 전체에 비해 맞은 비율 == accuracy가 나옵니다

test set으로 loss 및 acc 를 계산하는 evaluate함수

def evaluate(model, testloader, device):
    model.eval()
    test_acc = 0
    test_loss = 0
    # Starts batch test
    for x_batch, y_batch in testloader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        pred = model(x_batch)
        pred = F.softmax(pred, dim=1)
        loss = nn.CrossEntropyLoss()
        loss_output=loss(pred, y_batch).item()
        test_loss += loss_output
        # acc = accuracy(pred, y_batch)
        prediction = torch.argmax(pred, dim=1)
        acc = accuracy(prediction, y_batch)
        test_acc += acc

    test_acc = test_acc/ len(testloader)
    test_loss = test_loss / len(testloader)

    return test_loss, test_acc

train함수에서 만드는 model, testlaoder와 device를 넘겨주면 testloader의 데이터들로 acc와 loss를 계산합니다.

일단 평가기 때문에 model은 evaluation 모드로 바꿔주고 (model.eval())

x_batch, y_batch불러와 device에 넣고

model(x_batch) 로 예측을 수행한 후

F.softmax를 통해 softmax를 취한 값을 y_batch와 함께 loss 함수에 넣어주어 loss 를 구해줍니다.