将标签中大于0的像素值(类别)挑选出来。
label = [0,1,2,3]
mask = label > 0
print(mask)
运行时候出现:TypeError: ‘>’ not supported between instances of ‘list’ and ‘int’
因为label是list不能和0比较,所以需要对label格式进行修改。
添加一句:
label = torch.Tensor(label)
或者
label = np.numpy(label)
取决于自己的数据类型,在训练过程中,label已经加载到cuda上了,所以他一定是一个tensor格式,训练时候不必担心。
返回一个布尔类型:
tensor([False, True, True, True])
文章出处登录后可见!
已经登录?立即刷新