如何解决 Pytorch 中的矩阵维度不匹配?
pytorch 217
原文标题 :How to solve a matrix dimension mismatch in Pytorch?
我正在重建一个神经网络来对 MNIST 数据集进行分类。这个数据集包含手写数字 0-9 的像素图片,标签是图片对应的数字。我正在仿照 CIFAR-10 数据集的神经网络教程(https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/17a7c7cb80916fcdf921097825a0f562/cifar10_tutorial.ipynb#scrollTo=WygR9nmCEtFT)来实现它。以下是网络的主要部分。
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """LeNet-style CNN from the PyTorch CIFAR-10 tutorial, sized for any square input.

    The tutorial hard-codes ``fc1`` to ``16 * 5 * 5`` input features, which is only
    correct for 32x32 images. A 28x28 MNIST image leaves a 16x4x4 feature map after
    the two conv/pool stages. Flattening it therefore yields 256 features, and
    ``fc1`` fails with "mat1 and mat2 shapes cannot be multiplied (Bx256 and
    400x120)". The flattened width is now derived from ``image_size``.

    Args:
        in_channels: channels per input image (3 for CIFAR-10 RGB, 1 for MNIST).
        image_size: height/width of the square input images (32 for CIFAR-10,
            28 for MNIST). Must be at least 16.
        num_classes: number of output logits produced by ``fc3``.

    Raises:
        ValueError: if ``image_size`` is too small to survive two 5x5-conv +
            2x2-pool stages.

    Usage:
        cifar_net = Net()                              # 3x32x32 inputs (unchanged default)
        mnist_net = Net(in_channels=1, image_size=28)  # 1x28x28 inputs
    """

    def __init__(self, in_channels=3, image_size=32, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Each unpadded 5x5 conv shrinks a spatial side by 4; each 2x2/stride-2 pool
        # halves it (floor). 32 -> 28 -> 14 -> 10 -> 5 and 28 -> 24 -> 12 -> 8 -> 4.
        side = ((image_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(
                f"image_size={image_size} is too small for two conv5 + pool2 "
                "stages; need at least 16"
            )
        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map an (N, C, H, W) image batch to (N, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dim. The Tensor method is used so
        # the block does not depend on a separately imported `torch` name.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
import torch.optim as optim

# Model, loss function and optimizer, configured as in the CIFAR-10 tutorial.
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Two passes over the training data. `trainloader` is defined elsewhere; each item
# it yields unpacks into (inputs, labels).
for epoch in range(2):
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        # One optimization step: clear stale gradients, forward, backprop, update.
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()

        # Report the mean per-batch loss over each window of 2000 mini-batches.
        running_loss += loss.item()
        if (batch_idx + 1) % 2000 == 0:
            print(f'[{epoch + 1}, {batch_idx + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
以上是针对 cifar-10 的。我将 conv1 中的 (3,6,5) 更改为 (1,6,5) 以适应单通道 MNIST 图像。但是,当我运行该程序时,它会引发此错误:
RuntimeError Traceback (most recent call last)
<ipython-input-15-7b5795e8cda9> in <module>()
10
11 # forward + backward + optimize
---> 12 outputs = net(inputs)
13 loss = criterion(outputs, labels)
14 loss.backward()
4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
<ipython-input-14-bc1756c36393> in forward(self, x)
17 x = self.pool(F.relu(self.conv2(x)))
18 x = torch.flatten(x, 1) # flatten all dimensions except batch
---> 19 x = F.relu(self.fc1(x))
20 x = F.relu(self.fc2(x))
21 x = self.fc3(x)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py in forward(self, input)
101
102 def forward(self, input: Tensor) -> Tensor:
--> 103 return F.linear(input, self.weight, self.bias)
104
105 def extra_repr(self) -> str:
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
1846 if has_torch_function_variadic(input, weight, bias):
1847 return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848 return torch._C._nn.linear(input, weight, bias)
1849
1850
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x256 and 400x120)
我不知道问题究竟出在哪里,也不知道该如何解决。如果您能提供帮助,我将不胜感激!谢谢!