有没有办法在pytorch的损失函数中包含一个计数器(一个计算某些东西的变量)?

青葱年少 pytorch 516

原文标题is there any way to include a counter(a variable that count something) in a loss function in pytorch?

这些是我的损失函数中的一些行。output是多类分类网络的输出。

bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])

dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)

我想dr_output.sum()成为我损失函数的一部分。但是我的实现有很多限制。有些函数在pytorch中是不可微分的,而且dr_output可能为零,如果我只使用dr_output作为我的损失,这也是不允许的。任何人都可以向我建议解决这些问题的方法吗?

原文链接:https://stackoverflow.com//questions/71418817/is-there-any-way-to-include-a-countera-variable-that-count-something-in-a-loss

回复

我来回复
  • aretor的头像
    aretor 评论

    如果我理解正确:

    bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])
    

    计算每行有多少个元素大于.1

    反而:

    dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)
    

    如果对应行中只有一个大于.1的元素为真,则预测正确。

    dr_output.sum()然后计算有多少行验证了这个条件,因此最小化损失可能会强制执行不正确的预测或更多值大于.1的分布。

    考虑到这些因素,您可以使用以下方法估算您的损失:

    import torch.nn.functional as F
    
    # x are the inputs, y the labels
    
    mask = x > 0.1
    p = F.softmax(x, dim=1)
    out = p * (mask.sum(dim=1, keepdim=True) == 1)
    
    loss = out[torch.arange(x.shape[0]), y].sum()
    

    您可以设计更适合您的问题的类似变体。

    2年前 0条评论