How do I mask a feed forward layer based on tensor in pytorch?


I have a very simple network with 2 inputs (x and m). x is of size 100 and m is of size 3.

My network is simply…

f_1 = linear_layer(x)

f_2 = linear_layer(f_1)
f_3 = linear_layer(f_1)
f_4 = linear_layer(f_1)

f_5 = softmax(linear_layer(sum(f_2, f_3, f_4)))

Based on the vector m, I want to zero out and ignore f_2, f_3 and f_4 in the final sum and in the resulting gradient computation. Is there a way to create a mask from the vector m that achieves this?
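
For concreteness, here is a minimal runnable sketch of the setup described above; the input size 100 and mask size 3 come from the question, while every other width and value is an assumption for illustration:

    import torch

    # sizes from the question: x has size 100, m has size 3
    x = torch.randn(100)
    m = torch.tensor([1.0, 0.0, 1.0])  # example mask values (assumed)

    # the hidden width of 100 and output width of 10 are assumptions
    linear_1 = torch.nn.Linear(100, 100)
    linear_2 = torch.nn.Linear(100, 100)
    linear_3 = torch.nn.Linear(100, 100)
    linear_4 = torch.nn.Linear(100, 100)
    linear_5 = torch.nn.Linear(100, 10)

    f_1 = linear_1(x)
    f_2 = linear_2(f_1)
    f_3 = linear_3(f_1)
    f_4 = linear_4(f_1)
    f_5 = torch.softmax(linear_5(f_2 + f_3 + f_4), dim=-1)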

Original link: https://stackoverflow.com//questions/71532067/how-do-i-mask-a-feed-forward-layer-based-on-tensor-in-pytorch

Replies

user947659 replied:

    OK, this is how you do it. Use a list comprehension to make it more general (a sketch of that generalization follows the gradient output below):

    import torch
    
    # example input and output
    x = torch.ones(5)
    y = torch.zeros(3)
    
    # mask tensor
    mask = torch.tensor([0, 1, 0])
    
    # initial layer
    z0 = torch.nn.Linear(5, 5)
    
    # layers to potentially mask
    z1 = torch.nn.Linear(5, 3)
    z2 = torch.nn.Linear(5, 3)
    z3 = torch.nn.Linear(5, 3)
    
    # forward pass: a specific mask element is applied to each of the maskable layers
    layer1_output = z0(x)
    layer2_output = mask[0]*z1(layer1_output) + mask[1]*z2(layer1_output) + mask[2]*z3(layer1_output)
    
    # loss function
    loss = torch.nn.functional.binary_cross_entropy_with_logits(layer2_output, y)
    
    # run it and see
    loss.backward()
    print(z0.weight.grad)
    print(z1.weight.grad)
    print(z2.weight.grad)
    print(z3.weight.grad)
    

    As shown below, the mask tensor effectively selects, based on the mask elements, the subnetworks that the computation (and hence the gradients) are applied to:

    tensor([[ 0.0354,  0.0354,  0.0354,  0.0354,  0.0354],
            [-0.0986, -0.0986, -0.0986, -0.0986, -0.0986],
            [-0.0372, -0.0372, -0.0372, -0.0372, -0.0372],
            [-0.0168, -0.0168, -0.0168, -0.0168, -0.0168],
            [-0.0133, -0.0133, -0.0133, -0.0133, -0.0133]])
    tensor([[-0., 0., 0., -0., 0.],
            [-0., 0., 0., -0., 0.],
            [-0., 0., 0., -0., 0.]])
    tensor([[-0.0422,  0.1314,  0.1108, -0.1644,  0.0906],
            [-0.0240,  0.0747,  0.0630, -0.0934,  0.0515],
            [-0.0251,  0.0781,  0.0659, -0.0977,  0.0539]])
    tensor([[-0., 0., 0., -0., 0.],
            [-0., 0., 0., -0., 0.],
            [-0., 0., 0., -0., 0.]])
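
    The list-comprehension generalization mentioned at the top of this answer might look like the following; a minimal sketch, where the layers list and the zip-based sum are my own illustration rather than code from the original answer:

    import torch

    x = torch.ones(5)
    y = torch.zeros(3)
    mask = torch.tensor([0, 1, 0])

    z0 = torch.nn.Linear(5, 5)
    # build any number of maskable parallel layers, one per mask element
    layers = [torch.nn.Linear(5, 3) for _ in range(len(mask))]

    h = z0(x)
    # each mask element gates the output of the corresponding layer
    out = sum(m_i * layer(h) for m_i, layer in zip(mask, layers))

    loss = torch.nn.functional.binary_cross_entropy_with_logits(out, y)
    loss.backward()

    # layers whose mask element is 0 receive all-zero gradients
    for layer in layers:
        print(layer.weight.grad)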
    