重新理解一个类中的forward()和__init__()函数

forward()函数和__init__()的关系

__init__() 是一个类的构造函数,用于初始化对象的属性。它会在创建对象时自动调用,而且通常在这里完成对象所需的所有初始化操作。

forward() 是一个神经网络模型中的方法,用于定义数据流的向前传播过程。它接受输入数据,通过网络的各个层进行计算,最终返回输出结果。

在神经网络的 PyTorch 实现中,__init__() 方法通常用于实例化各个网络层(例如卷积层、池化层、全连接层的维度等【这里只是执行了初始化,但是可以通过后面实例化时调用的forward()重新给神经网络维度赋值】),并设置各层的超参数(例如卷积核大小、步幅、填充等)。而 forward() 方法则定义了这些网络层之间的计算顺序与逻辑,它负责将输入数据传递到网络中,并返回计算结果【这里输入进forward的数据维度要和forward()接收的第一个参数维度相同,虽然你看它只接受了一个参数‘x’,但是这个x的维度是多维的(在本代码中就是(input_dim, hidden_dim)两个大维度),而不是普通意义上的一个自然数

因此,两个方法通常一起使用,__init__() 用于设置网络结构和超参数,forward() 则定义了从输入到输出的完整计算流程。

例子:

定义类:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

在上面的代码中,我们定义了一个名为 SimpleNet 的神经网络模型,它继承自 PyTorch 中的 nn.Module 类。我们在 __init__() 方法中定义了三层网络结构,分别是输入层 fc1、激活层 relu 和输出层 fc2。其中,输入层和输出层都使用了全连接层(nn.Linear),而激活层使用了 ReLU 激活函数。

forward() 方法中,我们按照输入数据 x 经过 fc1relufc2 三层的顺序进行计算,最终返回输出结果 out

调用

调用上述代码的 forward() 方法需要先创建一个 SimpleNet 类的对象,并将输入数据传递给该对象。以下是一个简单的示例:

# 创建一个 SimpleNet 对象,设置输入维度为 10,隐藏层维度为 20,输出维度为 5
net = SimpleNet(10, 20, 5)

# 构造一个随机的输入张量,大小为 [batch_size, input_dim],这里令 batch_size=1
input_tensor = torch.randn(1, 10)

# 将输入张量传入网络中,得到输出张量
output_tensor = net(input_tensor)

# 打印输出张量的形状
print(output_tensor.shape)

为什么上面的代没有看到 __init__()、forword()函数的出现就完成了上述代码的调用呢?

初始化一个类时,则自动调用了该类的 __init__() 方法【net = SimpleNet(10, 20, 5)】

调用一个类的实例时,会自动调用该类的forward() 方法【output_tensor = net(input_tensor)】

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年9月2日
下一篇 2023年9月2日

相关推荐