在PyTorch中,forward()是如何被调用的?了解 __call__, __init__,forward,super()

forward()是怎么被调用的

一、问题描述

  看了一个源码,从最开始看到最后就看到了每个函数里面都有一个def forward()方法。但是,没有看到调用的地方,甚至是参数,和方法的参数都不一样了。
  那代码怎么看呐?从网上看了好多的方法,总是差一点,没有那么明显,就要到了重要的地方了,就结束了。

小Tip-段首缩进

写博客的时候你会发现,段首没有办法缩进;
如果使用两个Tab键/四个空格键的话,就成了下面这样:

	1111111111111
2222222

解决方法:使用特殊占位符,不同占位符所占空白是不一样大的。

  or    表示一个半角的空格
  or    表示一个全角的空格
    两个全角的空格(用的比较多)
  or    不断行的空白格

效果展示

  表示一个半角的空格
  表示一个全角的空格
   两个全角的空格(用的比较多)
  不断行的空白格

二、实例实现方法

1.正常的Python类的执行顺序

class Foo():
    def __init__(self, x):
        print("this is class of Foo.")
        print("Foo类属性初始化")
        self.f = "foo"
        print(x)

class Children(Foo):
    def __init__(self):
        y = 1
        print("this is class of Children.")
        print("Children类属性初始化")
        super(Children, self).__init__(y)  
        # 进入`Foo类`中,向`Foo类`中传入参数`y`,同时初始化`Foo类属性`。
        
a = Children()
print(***)
print(a)
print(***)
print(a.f)

输出结果

***
this is class of Children.
Children类属性初始化
this is class of Foo.
Foo类属性初始化
1
***
foo

执行顺序
创建实例化对象:a = Children()

执行print(a)–>进入Childern类–>初始化Childern类参数,执行def __init__(self):函数 –>进入Children父类Foo,传入参数y并初始化父类Foo参数super(Children, self).__init__(y),执行Foo中的参数初始化。

2.在神经网络中的应用

2.1 参考一段代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        #1个输入图像通道,6个输出通道,3x3平方卷积核
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension 
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
 
net = Net()
print(net)

输出结果

>>>Net(
 
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
 
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
 
  (fc1): Linear(in_features=576, out_features=120, bias=True)
 
  (fc2): Linear(in_features=120, out_features=84, bias=True)
 
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
总结:

那关于forward是怎么被调用的,模型训练时,不需要使用forward,只要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数

2.2 接下来我们看几个例子,了解一下__call__, init,forward

# list 1
class A():
    def __call__(self):
        print('i can be called like a function')
 
a = A()
a()

输出结果

>>> i can be called like a function
# list 2
class B():
    def __init__(self):
        print('i can be called like a function')
 
a = B()
print(a)

输出结果:

>>>i can be called like a function
>>><__main__.B at 0x5978fd0>
# list 3
class C():
    def __init__(self):
        print('a')
    def __call__(self):
        print('b')    
 
a = C()
a()

输出结果:

>>>a
>>>b
 # list 4
class D():
    def __init__(self, init_age):
        print('我年龄是:',init_age)
        self.age = init_age
 
    def __call__(self, added_age):
        res = self.forward(added_age)
        return res
 
    def forward(self, input_):
        print('forward 函数被调用了')
        
        return input_ + self.age
print('对象初始化。。。。')
a = D(10)
print("*************split line*************") 
input_param = a(2)
print("我现在的年龄是:", input_param)

输出结果:

>>>对象初始化。。。。
我年龄是: 10
*************split line*************
forward 函数被调用了
我现在的年龄是: 12
总结:

pytorch主要也是就是按照__call__, init,forward三个函数实现网络层之间的架构的,所以从以上例子可看出,定义__call__方法的类可以当作函数调用,当把定义的网络模型model当作函数调用的时候就自动调用定义的网络模型的forward方法。

2.3 上面提到了super继承,也需要了解一下

# list 1
class Net(nn.Module):
    def __init__(self):
        print("this is Net")
        self.a = 1
        super(Net, self).__init__()

    def forward(self, x):
        print("this is forward of Net", x)

class Children(Net):
    def __init__(self):
        print("this is children")
        super(Children, self).__init__()
a = Children()
print("*************split line*************") 
a(1)

输出结果:

this is children
this is Net
*************split line*************
this is forward of Net 1
总结:

上面代码执行顺序:

创建实例化对象a = Children()
执行a:进入Childern类 –>初始化Childern类参数,执行def __init__(self): –>进入Children父类Net, 并初始化父类Net参数super(Children, self).__init__(),执行Net中的参数初始化 –>传入参数并执行父类Net的forward()函数

2.4 比较好的一个例子带有nn.Module源码部分

class FPN(nn.Module):
	def __init__(self,in_channels_list,out_channels):
		self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1)
        
    def forward(self, input):
    	input = list(input.values())
    	
    	output1 = self.output1(input[0])
        output2 = self.output2(input[1])
        output3 = self.output3(input[2])
		
		out = [output1, output2, output3]
		return out
class RetinaFace(nn.Module):
	def __init__(self):
		# 定义层结构,举例如下
		self.fpn = FPN()
	def forward(self, inputs):
		out = self.fpn(inputs)
		return out

net = RetinaFace()
out = net(image) # 图像作为输入,经过net做正向传播,得到输出(分类/框/。。。)	

看一下nn.Module的定义(截取有用部分):

class Module(object):
	def forward(self, *input):
	        r"""Defines the computation performed at every call.
	
	        Should be overridden by all subclasses.
	
	        .. note::
	            Although the recipe for forward pass needs to be defined within
	            this function, one should call the :class:`Module` instance afterwards
	            instead of this since the former takes care of running the
	            registered hooks while the latter silently ignores them.
	        """
	        raise NotImplementedError

	def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs) # 重点!!
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result
总结:
  1. forward是在__call__中调用的,而__call__函数是在类的对象使用‘()’时被调用(如:此例的net(image))。一般调用在类中定义的函数的方法是:example_class_instance.func(),如果只是使用example_class_instance(),那么这个操作就是在调用__call__这个内置方法
  2. 分析此例:
    从第一步来看,out = net(image)实际上就是调用了net__call__方法net__call__方法没有显式定义,那么就使用它的父类方法,也就是调用nn.Module的__call__方法,它调用了forward方法,又有,net类中定义了forward方法,所以使用重写的forward方法
  3. 依次类推,out = self.fpn(inputs)也是先调用__call__方法,它进一步调用forward方法,而forward方法被FPN类重写,故调用重写后的forward方法
  4. __call__方法中调用forward方法,由于每个网络如上述的RetinaFace和FPN都重写了forward方法,所以,当调用forward时,都调用的是重写之后的版本。至此,回答了上述问题。

2.5 在pytorch 中没有调用模型的forward()前向传播,只实列化后把参数传入

定义模型:

class Module(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        # ......
  
    def forward(self, x):
        # ......
        return x

data = .....  #输入数据
# 实例化一个对象
module = Module()
#  前向传播 直接把输入传入实列化
module(data)  
#没有使用module.forward(data)  

#实际上module(data)  等价于module.forward(data)   

等价的原因是因为 python calss 中的__call__ 可以让类像函数一样调用,当执行model(x)的时候,底层自动调用forward方法计算结果。

在__call__ 里可调用其它的函数:

class A():
    def __call__(self, param):
        
        print('我在__call__中,传入参数',param)
 
        res = self.forward(param)
        return res
 
    def forward(self, x):
        print('我在forward函数中,传入参数类型是值为: ',x)
        return x
 
a = A()
y = a('i')
print("*****")
print("传入的参数是:", y)

输出结果:

  >>> 我在__call__中,传入参数 i
  >>>我在forward函数中,传入参数类型是值为:  i
  >>>*****
  >>>传入的参数是: i

2.6 实例

nn.Module 的__call__方法部分源码如下所示:

def __call__(self, *input, **kwargs):
   result = self.forward(*input, **kwargs)
   for hook in self._forward_hooks.values():
       #将注册的hook拿出来用
       hook_result = hook(self, input, result)
   ...
   return result

可以看到,当执行model(x)的时候,底层自动调用forward方法计算结果。

class Animal():
    def __init__(self,name): 
        self.name = name
    def greet(self):
        print('animal is %s' % self.name)
class Dog(Animal):
    def greet(self):
        super(Dog, self).greet()
        print('wangwang')
a=Dog('dog')
a.greet()

输出结果:

>>>animal is dog
>>>wangwang
总结:
那么调用forward方法的具体流程是什么样的呢?具体流程是这样的:

以一个Module为例:
1. 调用Module的call方法
2. Module的call里面调用Module的forward方法
3. forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下
4. 调用Function的call方法
5. Function的call方法调用了Function的forward方法。
6. Function的forward返回值
7. Module的forward返回值
8. 在Module的call进行forward_hook操作,然后返回值。

参考链接

  1. 参考链接一
  2. 参考链接二
  3. 参考链接三
  4. 参考链接四
  5. 参考链接五

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(2)
青葱年少的头像青葱年少普通用户
上一篇 2022年5月19日
下一篇 2022年5月19日

相关推荐