在 Pytorch 中查找子张量不等于给定张量的索引

青葱年少 pytorch 290

原文标题Find index where a sub-tensor does not equal to a given tensor in Pytorch

例如,我有一个张量,

a = [[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]]

形状(4,4)。如何找到特定子张量的索引

[-1,-1,-1,-1] 

使用 PyTorch 不会出现。我想要得到的预期输出是

[0,2]

原文链接:https://stackoverflow.com//questions/71595684/find-index-where-a-sub-tensor-does-not-equal-to-a-given-tensor-in-pytorch

回复

我来回复
  • BrokenBenchmark的头像
    BrokenBenchmark 评论

    您可以使用torch.any()比较张量每一行的元素,然后使用.nonzero().flatten()生成索引:

    torch.any(a != torch.Tensor([-1, -1, -1, -1]), axis=1).nonzero().flatten()
    

    例如,

    import torch
    
    a = torch.Tensor([[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]])
    result = torch.any(a != torch.Tensor([-1, -1, -1, -1]), axis=1).nonzero().flatten()
    print(result)
    

    输出:

    tensor([0, 2])
    
    2年前 0条评论
  • Phoenix的头像
    Phoenix 评论

    你也可以用wherenonzero

    a = torch.Tensor([[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]])
    
    b = torch.Tensor([-1,-1,-1,-1])
    
    result = torch.where(a != b)[0].unique()
    
    result = torch.nonzero(a != b, as_tuple=True)[0].unique()
    
    print(result)
    
    2年前 0条评论