pytorch框架:conv1d、conv2d的输入数据维度是什么样的

文章目录

  • Conv1d
  • Conv2d

Conv1d

Conv1d 的输入数据维度通常是一个三维张量,形状为 (batch_size, in_channels, sequence_length),其中:

batch_size 表示当前输入数据的批次大小;
in_channels 表示当前输入数据的通道数,对于文本分类任务通常为 1,对于图像分类任务通常为 3(RGB)、1(灰度)等;
sequence_length 表示当前输入数据的序列长度,对于文本分类任务通常为词向量的长度,对于时序信号处理任务通常为时间序列的长度,对于图像分类任务通常为图像的高或宽。
具体来说,Conv1d 模块会对第二维和第三维分别进行一维卷积操作,保留第一维(即批次大小)不变,输出一个新的三维张量,形状为 (batch_size, out_channels, new_sequence_length),其中 out_channels 表示卷积核的数量,new_sequence_length 表示卷积后的序列长度。

示例:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=2),
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=2),
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=2)
        )
        self.fc = nn.Linear(128, 2)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
x = torch.randn(200,6)
# x = x.unsqueeze(1)
net = Net()
output = net(x)
print(x.shape)

Conv2d

在 PyTorch 中,使用 nn.Conv2d 创建卷积层时,输入数据的维度应该是 (batch_size, input_channels, height, width)。其中,

batch_size 表示当前输入数据的批次大小;
input_channels 表示当前输入数据的通道数,对于彩色图像通常为 3(RGB),对于灰度图像通常为 1;
height 和 width 分别表示输入数据的高和宽。因此,在 PyTorch 框架中,Conv2d 的输入数据维度应该是一个四维张量,形状为 (batch_size, input_channels, height, width)。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2023年11月9日
下一篇 2023年11月9日

相关推荐