基础积累:图像分割损失函数最全面、最详细总结,含代码

作者:sfxiang
首发:ai算法修炼营
这是一篇关于图像分割损失函数的总结,具体包括:
binary cross entropy
weighted cross entropy
balanced cross entropy
dice loss
focal loss
tversky loss
focal tversky loss
log-cosh dice loss (本文提出的新损失函数)
代码地址:https://github.com/shruti-jadon/semantic-segmentation-loss-functions
项目推荐:https://github.com/junma11/segloss
图像分割一直是一个活跃的研究领域,因为它有可能修复医疗领域的漏洞,并帮助大众。在过去的5年里,各种论文提出了不同的目标损失函数,用于不同的情况下,如偏差数据,稀疏分割等。在本文中,总结了大多数广泛用于图像分割的损失函数,并列出了它们可以帮助模型更快速、更好的收敛模型的情况。此外,本文还介绍了一种新的log-cosh dice损失函数,并将其在nbfs skull-stripping数据集上与广泛使用的损失函数进行了性能比较。某些损失函数在所有数据集上都表现良好,在未知分布数据集上可以作为一个很好的选择。
简介
深度学习彻底改变了从软件到制造业的各个行业。深度学习在医学界的应用也十分广泛,例如使用u-net进行肿瘤分割、使用segnet进行癌症检测等。在这些应用中,图像分割是至关重要的,分割后的图像除了告诉我们存在某种疾病外,还展示了它到底存在于何处,这为实现自动检测ct扫描中的病变等功能提供基础保障。
图像分割可以定义为像素级别的分类任务。图像由各种像素组成,这些像素组合在一起定义了图像中的不同元素,因此将这些像素分类为一类元素的方法称为语义图像分割。在设计基于复杂图像分割的深度学习架构时,通常会遇到了一个至关重要的选择,即选择哪个损失/目标函数,因为它们会激发算法的学习过程。损失函数的选择对于任何架构学习正确的目标都是至关重要的,因此自2012年以来,各种研究人员开始设计针对特定领域的损失函数,以为其数据集获得更好的结果。
在本文中,总结了15种基于图像分割的损失函数。被证明可以在不同领域提供最新技术成果。这些损失函数可大致分为4类:基于分布的损失函数,基于区域的损失函数,基于边界的损失函数和基于复合的损失函数( distribution-based,region-based,  boundary-based,  and  compounded)。
本文还讨论了确定哪种目标/损失函数在场景中可能有用的条件。除此之外,还提出了一种新的log-cosh dice损失函数用于图像语义分割。为了展示其效率,还比较了nbfs头骨剥离数据集上所有损失函数的性能。
distribution-based loss
1.  binary cross-entropy:二进制交叉熵损失函数
交叉熵定义为对给定随机变量或事件集的两个概率分布之间的差异的度量。它被广泛用于分类任务,并且由于分割是像素级分类,因此效果很好。在多分类任务中,经常采用 softmax 激活函数+交叉熵损失函数,因为交叉熵描述了两个概率分布的差异,然而神经网络输出的是向量,并不是概率分布的形式。所以需要 softmax激活函数将一个向量进行“归一化”成概率分布的形式,再采用交叉熵损失函数计算 loss。
交叉熵损失函数可以用在大多数语义分割场景中,但它有一个明显的缺点:当图像分割任务只需要分割前景和背景两种情况。当前景像素的数量远远小于背景像素的数量时,即的数量远大于的数量,损失函数中的成分就会占据主导,使得模型严重偏向背景,导致效果不好。
#二值交叉熵,这里输入要经过sigmoid处理 import torch import torch.nn as nn import torch.nn.functional as f nn.bceloss(f.sigmoid(input), target) #多分类交叉熵, 用这个 loss 前面不需要加 softmax 层 nn.crossentropyloss(input, target)  
2、weighted binary cross-entropy加权交叉熵损失函数
class weightedcrossentropyloss(torch.nn.crossentropyloss):        network has to have no nonlinearity!        def __init__(self, weight=none):        super(weightedcrossentropyloss, self).__init__()        self.weight = weight    def forward(self, inp, target):        target = target.long()        num_classes = inp.size()[1]        i0 = 1        i1 = 2        while i1  0 reduces the relative loss for well-classified examples (p>0.5) putting more                    focus on hard misclassified example    :param smooth: (float,double) smooth value when cross entropy    :param balance_index: (int) balance class index, should be specific when alpha is float    :param size_average: (bool, optional) by default, the losses are averaged over each loss element in the batch.        def __init__(self, apply_nonlin=none, alpha=none, gamma=2, balance_index=0, smooth=1e-5, size_average=true):        super(focalloss, self).__init__()        self.apply_nonlin = apply_nonlin        self.alpha = alpha        self.gamma = gamma        self.balance_index = balance_index        self.smooth = smooth        self.size_average = size_average        if self.smooth is not none:            if self.smooth  1.0:                raise valueerror('smooth value should be in [0,1]')    def forward(self, logit, target):        if self.apply_nonlin is not none:            logit = self.apply_nonlin(logit)        num_class = logit.shape[1]        if logit.dim() > 2:            # n,c,d1,d2 -> n,c,m (m=d1*d2*...)            logit = logit.view(logit.size(0), logit.size(1), -1)            logit = logit.permute(0, 2, 1).contiguous()            logit = logit.view(-1, logit.size(-1))        target = torch.squeeze(target, 1)        target = target.view(-1, 1)        # print(logit.shape, target.shape)        #         alpha = self.alpha        if alpha is none:            alpha = torch.ones(num_class, 1)        elif isinstance(alpha, (list, np.ndarray)):            assert len(alpha) == num_class            alpha = torch.floattensor(alpha).view(num_class, 1)            alpha = alpha / alpha.sum()        elif isinstance(alpha, float):            alpha = torch.ones(num_class, 1)            alpha = alpha * (1 - self.alpha)            alpha[self.balance_index] = self.alpha        else:            raise typeerror('not support alpha type')                if alpha.device != logit.device:            alpha = alpha.to(logit.device)        idx = target.cpu().long()        one_hot_key = torch.floattensor(target.size(0), num_class).zero_()        one_hot_key = one_hot_key.scatter_(1, idx, 1)        if one_hot_key.device != logit.device:            one_hot_key = one_hot_key.to(logit.device)        if self.smooth:            one_hot_key = torch.clamp(                one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)        pt = (one_hot_key * logit).sum(1) + self.smooth        logpt = pt.log()        gamma = self.gamma        alpha = alpha[idx]        alpha = torch.squeeze(alpha)        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt        if self.size_average:            loss = loss.mean()        else:            loss = loss.sum()        return loss  
5、distance map derived loss penalty term距离图得出的损失惩罚项
可以将距离图定义为ground truth与预测图之间的距离(欧几里得距离、绝对距离等)。合并映射的方法有2种,一种是创建神经网络架构,在该算法中有一个用于分割的重建head,或者将其引入损失函数。遵循相同的理论,可以从gt mask得出的距离图,并创建了一个基于惩罚的自定义损失函数。使用这种方法,可以很容易地将网络引导到难以分割的边界区域。损失函数定义为:
class dispenalizedce(torch.nn.module):        only for binary 3d segmentation    network has to have no nonlinearity!        def forward(self, inp, target):        # print(inp.shape, target.shape) # (batch, 2, xyz), (batch, 2, xyz)        # compute distance map of ground truth        with torch.no_grad():            dist = compute_edts_forpenalizedloss(target.cpu().numpy()>0.5) + 1.0                dist = torch.from_numpy(dist)        if dist.device != inp.device:            dist = dist.to(inp.device).type(torch.float32)        dist = dist.view(-1,)        target = target.long()        num_classes = inp.size()[1]        i0 = 1        i1 = 2        while i1 0.5) + 1.0        # print('dist.shape: ', dist.shape)        dist = torch.from_numpy(dist)        if dist.device != net_output.device:            dist = dist.to(net_output.device).type(torch.float32)                tp = net_output * y_onehot        tp = torch.sum(tp[:,1,...] * dist, (1,2,3))                dc = (2 * tp + self.smooth) / (torch.sum(net_output[:,1,...], (1,2,3)) + torch.sum(y_onehot[:,1,...], (1,2,3)) + self.smooth)        dc = dc.mean()        return -dc  
2、hausdorff distance loss
hausdorff distance loss(hd)是分割方法用来跟踪模型性能的度量。
任何分割模型的目的都是为了最大化hausdorff距离,但是由于其非凸性,因此并未广泛用作损失函数。有研究者提出了基于hausdorff距离的损失函数的3个变量,它们都结合了度量用例,并确保损失函数易于处理。
class hddtbinaryloss(nn.module):    def __init__(self):                compute haudorff loss for binary segmentation        https://arxiv.org/pdf/1904.10030v1.pdf                        super(hddtbinaryloss, self).__init__()    def forward(self, net_output, target):                net_output: (batch_size, 2, x,y,z)        target: ground truth, shape: (batch_size, 1, x,y,z)                net_output = softmax_helper(net_output)        pc = net_output[:, 1, ...].type(torch.float32)        gt = target[:,0, ...].type(torch.float32)        with torch.no_grad():            pc_dist = compute_edts_forhdloss(pc.cpu().numpy()>0.5)            gt_dist = compute_edts_forhdloss(gt.cpu().numpy()>0.5)        # print('pc_dist.shape: ', pc_dist.shape)                pred_error = (gt - pc)**2        dist = pc_dist**2 + gt_dist**2 # /alpha=2 in eq(8)        dist = torch.from_numpy(dist)        if dist.device != pred_error.device:            dist = dist.to(pred_error.device).type(torch.float32)        multipled = torch.einsum(bxyz,bxyz->bxyz, pred_error, dist)        hd_loss = multipled.mean()        return hd_loss  
compounded loss
1、exponential logarithmic loss
指数对数损失函数集中于使用骰子损失和交叉熵损失的组合公式来预测不那么精确的结构。对骰子损失和熵损失进行指数和对数转换,以合并更精细的分割边界和准确的数据分布的好处。它定义为:
2、combo loss
组合损失定义为dice loss和修正的交叉熵的加权和。它试图利用dice损失解决类不平衡问题的灵活性,同时使用交叉熵进行曲线平滑。定义为:(dl指dice loss)


性能升级:华为新款MateBook X Pro极速上手体验
Aquanaut:结合ROV和AUV特性的水下机器人,目前处于设计阶段
究竟是什么影响电池储能大规模应用?
RISC和CISC架构有什么区别
宜家智能灯泡兼容苹果HomeKit
基础积累:图像分割损失函数最全面、最详细总结,含代码
热电冷却器控制-Thermoelectric Cooler
不需电源的短途电话,Home phone
Python爬虫绕过登录的小技巧
52070-007P 0.9mm SuperMini 插头(公头)直接焊接SOUTHWEST
N32L40XCL-STB开发板评测报告
一文带你从功率MOS入门到精通!
马斯克大脑芯片植入引发伦理担忧
电池材料量价齐飞,上市公司迎来业绩高光
什么是电抗器
怎么防止别人抄自己的电路板
Facebook布局虚拟现实(VR)和增强现实(AR)的原因是什么?
一些典型的电源测序应用,让你少走弯路
区块链市场的5个最新发展动态介绍
高通与苹果持续了近两年的纠纷近期有望和解