How to solve a matrix dimension mismatch in PyTorch?



I am rebuilding a neural network to classify the MNIST dataset. The dataset consists of pixel images of handwritten digits 0-9, with the corresponding digit as the label. I am trying to do this by copying the NN tutorial for the CIFAR-10 dataset (https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/17a7c7cb80916fcdf921097825a0f562/cifar10_tutorial.ipynb#scrollTo=WygR9nmCEtFT). Below is the main part of the network.

import torch  # needed for torch.flatten in forward()
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()


import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

The above is for CIFAR-10. I changed the (3, 6, 5) in conv1 to (1, 6, 5) to fit the single-channel MNIST images. However, when I run the program, it raises this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-15-7b5795e8cda9> in <module>()
     10 
     11         # forward + backward + optimize
---> 12         outputs = net(inputs)
     13         loss = criterion(outputs, labels)
     14         loss.backward()

4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-14-bc1756c36393> in forward(self, x)
     17         x = self.pool(F.relu(self.conv2(x)))
     18         x = torch.flatten(x, 1) # flatten all dimensions except batch
---> 19         x = F.relu(self.fc1(x))
     20         x = F.relu(self.fc2(x))
     21         x = self.fc3(x)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    101 
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 
    105     def extra_repr(self) -> str:

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 
   1850 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x256 and 400x120)

I have no idea where exactly the problem is or how to fix it. Any help would be greatly appreciated. Thanks!
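
A quick way to locate this kind of mismatch (a suggestion, not part of the original question) is to push a dummy batch through the convolutional part of the network and print the flattened shape; its second dimension is the number of features that actually reach fc1:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Rebuild only the layers in front of fc1, with conv1 taking 1 input channel as in the MNIST version.
conv1 = nn.Conv2d(1, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
pool = nn.MaxPool2d(2, 2)

dummy = torch.randn(4, 1, 28, 28)   # a fake batch of four 28x28 grayscale images
x = pool(F.relu(conv1(dummy)))
x = pool(F.relu(conv2(x)))
x = torch.flatten(x, 1)
print(x.shape)                       # torch.Size([4, 256]) -- but fc1 was built for 400 inputs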

Original link: https://stackoverflow.com//questions/71976528/how-to-solve-a-matrix-dimension-mismatch-in-pytorch

Replies

  • Aramakus replied:

    The problem is that the number of input features (400) is wrong here:

    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    

    It should be 256:

    self.fc1 = nn.Linear(256, 120)
    
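To see where 256 comes from (a worked check, assuming the standard 28x28 single-channel MNIST input): each 5x5 convolution without padding shrinks the spatial side by 4, and each 2x2 max-pool halves it.

# Spatial side of a 28x28 MNIST image after each layer (stride-1 convs, no padding, 2x2 pooling):
size = 28
size = size - 5 + 1     # conv1, 5x5 kernel -> 24
size = size // 2        # 2x2 max-pool      -> 12
size = size - 5 + 1     # conv2, 5x5 kernel -> 8
size = size // 2        # 2x2 max-pool      -> 4
print(16 * size * size) # 16 output channels * 4 * 4 = 256

The same arithmetic on CIFAR-10's 32x32 images gives 32 -> 28 -> 14 -> 10 -> 5, i.e. 16 * 5 * 5 = 400, which is why the tutorial's nn.Linear(16 * 5 * 5, 120) no longer matches once the input switches to MNIST.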