在 PyTorch 中使用阈值进行训练

扎眼的阳光 pytorch 287

原文标题Training with threshold in PyTorch

我有一个神经网络,当输入兴奋时会产生一个值。我需要使用网络返回的这个值来阈值另一个数组。这个阈值操作的结果用于计算损失函数(阈值的值事先不知道,需要通过训练得出)。下面是一个MWE

import torch

x = torch.randn(10, 1)  # Say this is the output of the network (10 is my batch size)
data_array = torch.randn(10, 2)  # This is the data I need to threshold
ground_truth = torch.randn(10, 2)  # This is the ground truth
mse_loss = torch.nn.MSELoss()  # Loss function

# Threshold
thresholded_vals = data_array * (data_array >= x)  # Returns zero in all places where the value is less than the threshold, the value itself otherwise

# Compute loss and gradients
loss = mse_loss(thresholded_vals, ground_truth)
loss.backward()  # Throws error here

由于阈值化操作返回一个没有任何梯度的张量数组,因此backward()操作会引发错误。

在这种情况下如何训练网络?

原文链接:https://stackoverflow.com//questions/71409752/training-with-threshold-in-pytorch

回复

我来回复
  • FlyingTeller的头像
    FlyingTeller 评论

    您的阈值函数在阈值中不可微,因此torch不计算阈值的梯度,这就是您的示例不起作用的原因。

    import torch
    
    x = torch.randn(10, 1, requires_grad=True)  # Say this is the output of the network (10 is my batch size)
    data_array = torch.randn(10, 2, requires_grad=True)  # This is the data I need to threshold
    ground_truth = torch.randn(10, 2)  # This is the ground truth
    mse_loss = torch.nn.MSELoss()  # Loss function
    
    # Threshold
    thresholded_vals = data_array * (data_array >= x)  # Returns zero in all places where the value is less than the threshold, the value itself otherwise
    
    # Compute loss and gradients
    loss = mse_loss(thresholded_vals, ground_truth)
    loss.backward()  # Throws error here
    print(x.grad)
    print(data_array.grad)
    

    输出:

    None #<- for the threshold x
    tensor([[ 0.1088, -0.0617],  #<- for the data_array
            [ 0.1011,  0.0000],
            [ 0.0000,  0.0000],
            [-0.0000, -0.0000],
            [ 0.2047,  0.0973],
            [-0.0000,  0.2197],
            [-0.0000,  0.0929],
            [ 0.1106,  0.2579],
            [ 0.0743,  0.0880],
            [ 0.0000,  0.1112]])
    
    2年前 0条评论