BB119A11-2F0F-47B7-A665-F40C04B4A256.jpeg

2A212107-19EE-4E30-8347-F2F10C00CFEF.jpeg

multiclass 대한 Dice coefficient 구현하기

배치 단위 이미지가 들어가면, 클래스별로 dice coefficient의 배치당 평균들의 클래스별 평균을 리턴하는 함수를 구현하고자 한다.

ex) batch size=2, num_classes = 3인 경우

클래스 0에 대한 dice coefficient의 평균

클래스 1에 대한 dice coefficient의 평균

클래스 2에 대한 dice coefficient의 평균 ⇒ 셋의 합 / 3 ( 클래스 수로 평균)

def dice_coefficient(pred, target, num_classes, ignore_idx=None):
    '''
    softmax_pred: (N, C, H, W), ndarray
    target : (N, H, W), ndarray
    '''
    assert pred.shape[0] == target.shape[0]
    
    if num_classes == 2:
        epsilon = 1e-6
        # pred = np.around(pred) # 0.0 of 1.0
        dice = 0
        # if both a and b are 1-D arrays, it is inner product of vectors(without complex conjugation)
        for batch in range(pred.shape[0]):
            inter = np.dot(pred[batch].reshape((-1,)), target[batch].reshape((-1,)))
            sum_sets = np.sum(pred[batch]) + np.sum(target[batch])
            dice += (2*inter+epsilon) / (sum_sets + epsilon)
        return dice / pred.shape[0]
        
    else:
        softmax = nn.Softmax(dim=1)
        pred = softmax(torch.from_numpy(pred).type(torch.float64))
        pred = np.array(pred)
        dice = 0
        for c in range(num_classes):
            if c==ignore_idx:
                continue
            dice += dice_coefficient(pred[:, c, :, :], np.where(target==c, 1, 0), 2, ignore_idx)
        return dice / num_classes

Dice Loss

$$ DiceLoss = 1-DiceCoefficient $$

$$ 0\leq DiceCoefficient \leq 1 $$

보통 클래스 개수가 1개일 때, binary classification일 때 사용을 많이 하며 이 때는 sigmoid를, multi class 대해서 사용하려면 softmax를 사용해야 한다

E1B72B52-846F-41AD-939D-2F2FE7EA17FC.jpeg

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

def dice_coefficient(pred:Tensor, target:Tensor, num_classes:int, ignore_idx=None):
    assert pred.shape[0] == target.shape[0]
    epsilon = 1e-6
    if num_classes == 2:
        dice = 0
        # if both a and b are 1-D arrays, it is inner product of vectors(without complex conjugation)
        for batch in range(pred.shape[0]):
            pred_1d = pred[batch].view(-1)
            target_1d = target[batch].view(-1)
            inter = (pred_1d * target_1d).sum()
            sum_sets = pred_1d.sum() + target_1d.sum()
            dice += (2*inter+epsilon) / (sum_sets + epsilon)
        return dice / pred.shape[0]
        
    
    elif num_classes == 1:
        dice = 0
        pred = F.Sigmoid(pred)
        for batch in range(pred.shape[0]):
            pred_1d = pred[batch].view(-1)
            target_1d = target[batch].view(-1)
            inter = (pred_1d * target_1d).sum()
            sum_sets = pred_1d.sum() + target_1d.sum()
            dice += (2*inter+epsilon) / (sum_sets + epsilon)
        
    else:
        pred = F.softmax(pred, dim=1).float()
        dice = 0
        for c in range(num_classes):
            if c==ignore_idx:
                continue
            dice += dice_coefficient(pred[:, c, :, :], torch.where(target==c, 1, 0), 2, ignore_idx)
        return dice / num_classes 

def dice_loss(pred, target, num_classes, ignore_idx=None):
    dice = dice_coefficient(pred, target, num_classes, ignore_idx)
    return 1 - dice

foreground에만 집중하는 loss라고 생각할 수 있는데, 응용하면 background까지도 고려하는 loss를 만들 수 있다.

$$ \frac{2(|A\cap B|+|(1-A)\cap (1-B)|)}{|A|+|B|+(2-|A|-|B|)} $$

Compound Loss = CrossEntropy Loss + Dice Loss