在 Pytorch 中查找子张量不等于给定张量的索引
pytorch 290
原文标题 :Find index where a sub-tensor does not equal to a given tensor in Pytorch
例如,我有一个张量,
a = [[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]]
形状(4,4)
。如何找到特定子张量的索引
[-1,-1,-1,-1]
使用 PyTorch 不会出现。我想要得到的预期输出是
[0,2]
回复
我来回复-
BrokenBenchmark 评论
您可以使用
torch.any()
比较张量每一行的元素,然后使用.nonzero()
和.flatten()
生成索引:torch.any(a != torch.Tensor([-1, -1, -1, -1]), axis=1).nonzero().flatten()
例如,
import torch a = torch.Tensor([[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]]) result = torch.any(a != torch.Tensor([-1, -1, -1, -1]), axis=1).nonzero().flatten() print(result)
输出:
tensor([0, 2])
2年前 -
Phoenix 评论
你也可以用
where
或nonzero
:a = torch.Tensor([[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]]) b = torch.Tensor([-1,-1,-1,-1]) result = torch.where(a != b)[0].unique() result = torch.nonzero(a != b, as_tuple=True)[0].unique() print(result)
2年前