Training with threshold in PyTorch
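I have a neural network that produces a single value when given an input. I need to use the value returned by the network to threshold another array. The result of this thresholding operation is used to compute a loss function (the threshold value is not known beforehand and has to be learned through training). Here is an MWE: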
import torch
x = torch.randn(10, 1) # Say this is the output of the network (10 is my batch size)
data_array = torch.randn(10, 2) # This is the data I need to threshold
ground_truth = torch.randn(10, 2) # This is the ground truth
mse_loss = torch.nn.MSELoss() # Loss function
# Threshold
thresholded_vals = data_array * (data_array >= x) # Returns zero in all places where the value is less than the threshold, the value itself otherwise
# Compute loss and gradients
loss = mse_loss(thresholded_vals, ground_truth)
loss.backward() # Throws error here
Since the thresholding operation returns a tensor that carries no gradient, the backward() operation throws an error.
How does one train a network in such a case?
Replies
FlyingTeller commented
This answer was accepted.
Your thresholding function is not differentiable with respect to the threshold, so torch does not compute a gradient for the threshold, which is why your example does not work. With requires_grad=True set on both tensors, the gradient reaches data_array but not x:
import torch
x = torch.randn(10, 1, requires_grad=True) # Say this is the output of the network (10 is my batch size)
data_array = torch.randn(10, 2, requires_grad=True) # This is the data I need to threshold
ground_truth = torch.randn(10, 2) # This is the ground truth
mse_loss = torch.nn.MSELoss() # Loss function
# Threshold
thresholded_vals = data_array * (data_array >= x) # Returns zero in all places where the value is less than the threshold, the value itself otherwise
# Compute loss and gradients
loss = mse_loss(thresholded_vals, ground_truth)
loss.backward() # Runs here because data_array requires grad
print(x.grad) # Gradient with respect to the threshold
print(data_array.grad) # Gradient with respect to the data
Output:
None #<- for the threshold x
tensor([[ 0.1088, -0.0617], #<- for the data_array
        [ 0.1011,  0.0000],
        [ 0.0000,  0.0000],
        [-0.0000, -0.0000],
        [ 0.2047,  0.0973],
        [-0.0000,  0.2197],
        [-0.0000,  0.0929],
        [ 0.1106,  0.2579],
        [ 0.0743,  0.0880],
        [ 0.0000,  0.1112]])
2 years ago
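One common workaround, which the accepted answer does not propose, is to swap the hard comparison for a smooth gate so that the threshold takes part in the autograd graph. The sketch below assumes a sigmoid gate is an acceptable approximation for the task; `sharpness` and `soft_mask` are hypothetical names introduced only for this illustration, and the value 10.0 is arbitrary.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

x = torch.randn(10, 1, requires_grad=True)  # stand-in for the network output (the threshold)
data_array = torch.randn(10, 2)             # data to threshold
ground_truth = torch.randn(10, 2)           # ground truth
mse_loss = torch.nn.MSELoss()

# Assumption: a hand-picked hyperparameter. Larger values bring the gate
# closer to the hard step, at the cost of smaller gradients away from the threshold.
sharpness = 10.0

# Smooth gate in (0, 1): near 1 where data_array is well above x,
# near 0 where it is well below. Unlike (data_array >= x), it is differentiable in x.
soft_mask = torch.sigmoid(sharpness * (data_array - x))
thresholded_vals = data_array * soft_mask

loss = mse_loss(thresholded_vals, ground_truth)
loss.backward()

print(x.grad)  # a (10, 1) tensor instead of None
```

The soft gate only approximates the original operation: values close to the threshold are scaled rather than kept or zeroed outright. If exact thresholding is needed at inference time, the hard comparison can still be applied once training is done.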