问题产生的原因是使用nn.CrossEntropyLoss()来计算损失的时候,target的维度超过4
import torch
import torch.nn as nn
logit = torch.ones(size=(4, 32, 256, 256)) # b,c,h,w
target = torch.ones(size=(4, 1, 256, 256))
criterion = nn.CrossEntropyLoss()
loss = criterion(logit, target)
如实target中的C不是1,则可以:
import torch
import torch.nn as nn
logit = torch.ones(size=(4, 32, 256, 256)) # b,c,h,w
target = torch.ones(size=(4, 2, 256, 256))
criterion = nn.CrossEntropyLoss()
losses = 0
for i in range(2):
loss = criterion(logit, target[:, i, ...].long())
losses += loss
可以看到代码里面有个.long(),如果不用的话则会报错:
RuntimeError: expected scalar type Long but found Float
文章出处登录后可见!
已经登录?立即刷新