解决:RuntimeError: mat1 and mat2 shapes cannot be multiplied (8×256 and 8×256)维度不匹配问题

在设计网络是,前面几层是去噪网络,后边几层是分类网络,因为之前没有接触过分类任务,对全连接层输入维度不太理解,出现错误RuntimeError: mat1 and mat2 shapes cannot be multiplied (8×256 and 8×256)
解决方法:查看上一层卷积的输出值大小,发现
原因:

卷积层的输入为四维[batch_size,channels,H,W] ,而全连接接受维度为2的输入,通常为[batch_size, size]。
所以需要进行变换
添加以下语句:

x = x.view(x.shape[0], -1)
得到大小为([8, 256])

而对于fc层需要根据上面的输出更改输入,及将下面语句的8改为256,跑通

self.fc1 = nn.Linear(8, 256)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年8月3日
下一篇 2023年8月3日

相关推荐