Runtime: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

pytorch框架在写test.py时,报了一个输入,输出类型不一致;

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

输出类型是torch.cuda.FloatTensor,是因为模型训练在gpu上跑的;

所以在加载模型时,指定设备为cpu这样就与输入类型保持一致。指定代码:

map_location=torch.device("cpu")

例如:

module = torch.load("vincent_0.pth", map_location=torch.device("cpu"))

然后问题解决了。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2022年4月20日 下午7:24
下一篇 2022年4月20日 下午7:30

相关推荐