对于训练时loss出现负值的情况


在训练时候loss出现负值,就立马停下来分析一下原因在哪。最有可能是损失函数出现问题,开始只使用交叉熵损失时没有出现过,在加上了dice loss时就出现了问题。于是就去dice loss中寻找原因。
1:首先需要明白语义分割的GT,每一个像素点的值就是像素的类别。

# -*- coding: utf-8 -*-
import numpy as np
from torchvision import transforms
import torch
from PIL import Image
img = Image.open('C:/Users/翰墨大人/Desktop/0003_lable.png') #图像所在位置
img1 = np.array(img)
img1 = torch.from_numpy(img1).type(torch.FloatTensor)
# trans = transforms.ToTensor()
# img1 = trans(img1)

a = torch.unique(img1) # 查看图片内的像素值
print(a)
print(img.mode) # 查看图片模式

打印结果:

tensor([ 0.,  1.,  5.,  7.,  8., 26., 29., 38., 40.])
P

原图一共有四十个类别,在003_lable这张GT上只出现了上述的类别。剩下的像素点和上述像素点是重复的。
注意:
这里将np格式转换为tensor格式时候,不能使用transforms.ToTensor(),他会将像素值发生改变。这样新的像素值点和类别就不是一一对应的。

tensor([0.0000, 0.0039, 0.0196, 0.0275, 0.0314, 0.1020, 0.1137, 0.1490, 0.1569])
P

2:语义分割的GT一共有四十类,没有通道,而在模型中pred的输出为一个通道。
在计算二分类dice loss时候,首先要将pred进行sigmoid。GT是四十个类别,如果是二分类的话那么标签就必须是[0,1]。而loss为负值的原因就是没有将标签转换为[0,1]。
X是pred,Y是GT,分子就是X和Y进行矩阵的相乘再相加,当Y中含有大于1的类别,比如30,40的话,而X又是进过sigmoid之后再(0,1)之内,那么分子除以分母的值就会大于1,造成dice loss就变成了负值。

如何将GT变为[0,1]呢?因为我需要的是对GT进行提边,使用一个拉普拉斯对GT进行卷积,然后再使用一个阈值,大于阈值为1,小于为0。为是边缘和不是边缘。GT就又灰度图像转换为了二值图像。

l = F.conv2d(x,sobel,padding=1,stride=1)
print(l.shape)
ll = torch.unique(l) # 查看图片内的像素值
print(ll)

l[l>0.1]=1
l[l<0.1]=0
l_ = torch.unique(l) # 查看图片内的像素值
print(l_)

结果:

对于多分类的dice loss:
GT需要进行one-hot编码,这样一个通道,像素点为(0-40)的GT就会变为四十个通道,每个通道像素点为(0,1),每个通道都可以看做一个二分类问题,属于该类别和不属于该类别。而pred的输出通道也为40,通过计算pred的每个通道和GT的每个通道的loss,最后求均值得到总loss。
3:代码代码参考
单分类代码:

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Model):
	def __init__(self):
		super(BinaryDiceLoss, self).__init__()
	
	def forward(self, input, targets):
		# 获取每个批次的大小 N
		N = targets.size()[0]
		# 平滑变量
		smooth = 1
		# 将宽高 reshape 到同一纬度
		input_flat = input.view(N, -1)
		targets_flat = targets.view(N, -1)
	
		# 计算交集
		intersection = input_flat * targets_flat 
		N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
		# 计算一个批次中平均每张图的损失
		loss = 1 - dice_eff.sum() / N
		return loss

多分类代码:

import torch
import torch.nn as nn

class MultiClassDiceLoss(nn.Module):
	def __init__(self, weight=None, ignore_index=None, **kwargs):
		super(MultiClassDiceLoss, self).__init__()
		self.weight = weight
		self.ignore_index = ignore_index
		self.kwargs = kwargs
	
	def forward(self, input, target):
		"""
			input tesor of shape = (N, C, H, W)
			target tensor of shape = (N, H, W)
		"""
		# 先将 target 进行 one-hot 处理,转换为 (N, C, H, W)
		nclass = input.shape[1]
		target = one_hot(target.long(), nclass)

		assert input.shape == target.shape, "predict & target shape do not match"
		
		binaryDiceLoss = BinaryDiceLoss()
		total_loss = 0
		
		# 归一化输出
		logits = F.softmax(input, dim=1)
		C = target.shape[1]
		
		# 遍历 channel,得到每个类别的二分类 DiceLoss
		for i in range(C):
			dice_loss = binaryDiceLoss(logits[:, i], target[:, i])
			total_loss += dice_loss
		
		# 每个类别的平均 dice_loss
		return total_loss / C

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年4月6日
下一篇 2023年4月6日

相关推荐