RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the

bug:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

源代码如下:

if __name__ == "__main__":
    from torchsummary import summary
    model = UNet()
    print(model)
    summary(model, input_size=(1, 480, 480))

在使用torchsummary可视化模型时候报错,报这个错误是因为类型不匹配,根据报错内容可以看出Input type为torch.FloatTensor(CPU数据类型),而weight type(即网络权重参数这些)为torch.cuda.FloatTensor(GPU数据类型)。

我们将model传到GPU上便可。将代码如下修改便可正常运行:

if __name__ == "__main__":
    from torchsummary import summary
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)	# modify
    print(model)
    summary(model, input_size=(1, 480, 480))

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年7月15日
下一篇 2023年7月15日

相关推荐