only batches of spatial targets supported (3D tensors) but got targets of dimension

问题产生的原因是使用nn.CrossEntropyLoss()来计算损失的时候,target的维度超过4

import torch
import torch.nn as nn

logit = torch.ones(size=(4, 32, 256, 256))  # b,c,h,w
target = torch.ones(size=(4, 1, 256, 256))

criterion = nn.CrossEntropyLoss()
loss = criterion(logit, target)

如实target中的C不是1,则可以:

import torch
import torch.nn as nn

logit = torch.ones(size=(4, 32, 256, 256))  # b,c,h,w
target = torch.ones(size=(4, 2, 256, 256))

criterion = nn.CrossEntropyLoss()
losses = 0
for i in range(2):
    loss = criterion(logit, target[:, i, ...].long())
    losses += loss

 可以看到代码里面有个.long(),如果不用的话则会报错:

RuntimeError: expected scalar type Long but found Float

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年7月12日
下一篇 2023年7月12日

相关推荐