Pytorch 中如何使用 self() 来生成预测?

xiaoxingxing pytorch 237

原文标题How is self() used in Pytorch to generate predictions?

class MNIST_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 28 * 28)
        out = self.linear(xb)
        return out
    
    def training_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        return loss

我在关注 Freecodecamp 教程。在 training_step 方法中,教程中说 = self(images) 用于 Generate Predictions 。

我无法理解如何使用 self 来获得预测。

原文链接:https://stackoverflow.com//questions/71427254/how-is-self-used-in-pytorch-to-generate-predictions

回复

我来回复
  • Ivan的头像
    Ivan 评论

    这实际上并不是 PyTorch 特有的,而是 Python 的工作方式。在对象上使用括号或直接在该类内部使用括号将调用名为__call__的特殊 Python 函数。此函数可用于您的类,因为您继承自实现它的nn.Module为你。

    这是 PyTorch 之外的一个最小示例:

    class A():
        def __call__(self):
            print('calling the object')
    
        def foo(self):
            self()
    

    然后

    >>> x = A()
    >>> x.foo() # prints out "calling the object"
    
    2年前 0条评论