图像分割的损失函数
基于分布相似度的损失
在此类损失函数中,主要使用信息论中的交叉熵机制来度量模型输出和真实目标之间所包含的信息的相似性。
交叉熵损失
假设$P(Y=0)=p$,$P(Y=1)=1-p$。预测值由logistic
/sigmoid
函数计算得出:
$P(\hat Y=0)= \frac{1}{1+e^{-z}}$和$P(\hat Y=1)=1- \frac{1}{1+e^{-z}}=1-\hat p$
交叉熵损失函数的定义形式如下:
上述交叉熵损失为二维交叉熵损失。
加权交叉熵损失
加权交叉熵损失(weighted corss entropy, WCE)是交叉熵损失的一种变体,其中,所有的正样本都被乘以一个系数以进行加权。该损失函数常用于类别不平衡问题中。例如,当你有一张有10%的黑像素和90%的白像素的图片时,常规的CE效果不会太好。
WCE定义如下式:
减少假负样本的数目,将$\beta$设置为大于1.增加假正样本的数目,将$\beta$设置为小于1。
weighted bce代码实现:
1 | def criterion_pixel(logit_pixel, truth_pixel): |
均衡交叉熵损失
均衡交叉熵损失(balanced corss entropy, BCE)和WCE类似,唯一的不同之处在于也对负样本进行了加权。
如下式所示:
Focal loss
在机器学习任务中,除了会遇到严重的类别样本数不均衡问题之外,经常也会遇到容易识别的样本数目和难识别的样本数目不均衡的问题。为了解决这一问题,何凯明大神提出了Focal loss。
Focal loss尝试降低easy example对损失的贡献,这样网络会集中注意力在难样本上。
FL定义如下:
上述公式为二分类问题的Focal loss,可以看出对于每一个样本,使用$(1-\hat p)^\gamma$作为其识别难易程度的指标,预测值$\hat p$越大代表对其进行预测越容易,因而其在总体损失中的占比应该越小。
对于多分类问题,其形式为:
对于每一个样本,$p_t$为模型预测出其属于其真实类别的概率,$\alpha_t$可用于调节不同类别之间的权重。将$\lambda$设置为0便可以得到BCE。
下述代码既可以计算二分类问题,也可以计算多分类问题的focal loss。
1 | class RobustFocalLoss2d(nn.Module): |
与最近的cell的距离
有文章在交叉熵损失的基础上加入了一个距离函数,是的网络学习两个结束目标的分离边界。如下所示:
其中$d_1(x)$和$d_2(x)$为两个距离函数计算点x与最近和第二近的cell的距离。
计算损失函数中的指数项会降低训练的速度,因此一般将距离和输入图片一起输入神经网络。
基于重合度的度量
Dice损失或F1分数
Dice系数
根据Lee Raymond Dice命名,是一种集合相似度度量函数,通常用于计算两个样本的相似度(取值范围为[0,1])。
Dice系数和Jaccard index类似:
$|X \bigcap Y|$表示两个样本之间的交集;$|X|$和$|Y|$分别表示X和Y的元素个数,分母的系数为2,用于处理两个集合存在交集的情况。可以看出DC大于IoU。
Dice损失
可以将dice系数定义为损失函数:
其中$p\in \{0,1\}^n$,$\hat p \in [0,1]$。实际计算方式如下:
分子的计算
将$|X \bigcap Y|$近似为预测图和真实标注之间的点乘,点乘的结果进行元素相加:
接着进行元素相加:
分母的计算
可以直接取元素的平方相加,也可以直接进行元素相加。
Dice损失比较适合样本极度不平衡的情况,但会对反向传播造成影响,导致反向传播不稳定。
Dice损失和交叉熵损失的对比及选择
选择交叉熵损失的原因:
交叉熵损失函数相比于基于重合度度量的损失函数,具有更简单的梯度形式,dice损失的梯度形式比较复杂。
选择Dice损失的原因:
该类损失的真实目标就是最大化预测值和真实值的重合度,更为直接,且一般来说,该损失更适用于类别不均衡问题。
weighted soft dice的代码实现:
1 | def criterion_pixel(logit_pixel, truth_pixel): |
Tversky损失
TL损失是DL的改进,该FP和FN加上权重。
将$\beta$设置为$\frac{1}{2}$,则变为DL损失。
交叉熵和Dice损失的结合
也有将交叉熵和Dice损失进行结合的损失函数:
代码如下:
1 | # reference: https://github.com/asanakoy/kaggle_carvana_segmentation |
为了解决样本类别不平衡问题,还可以使用加权BCE+DICE。
1 | class SoftDiceLoss(nn.Module): |