在 pytorch 张量中获取顶部映射值
pytorch 214
原文标题 :get top mapped values in pytorch tensor
我正在尝试获取张量中的前 N 个元素。我有一个映射告诉我如何对张量值进行排序
values_mapping = {1: 12, 3: 1, 4: 2, 2: 34, 12: 3}
tensor = torch.tensor([1, 4, 12, 2])
tensor.topk(3)
这里的结果应该是torch.tensor([1, 12, 2])
i.e。使用values_mapping
映射它们后的最高值
有没有办法使用火炬这样做?我们能告诉torch如何对它得到的值进行排序吗?
回复
我来回复-
Poe Dator 评论
我不知道是否存在更优雅的解决方案,但是您可以使用所需的键选择值,从值中选择前 k 个项目,使用 topk 索引来选择键:
values_mapping = {1: 12, 3: 1, 4: 2, 2: 34, 12: 3} tensor0 = torch.tensor([1, 4, 12, 2]) mapping0 = torch.Tensor([(k, v) for k, v in values_mapping.items() if k in tensor0]) topk = mapping0[:,1].topk(3) top_keys = mapping0[:,0][topk.indices] print(top_keys) >>> tensor([ 2., 1., 12.])
2年前