使用没有循环的坐标张量切片pytorch张量
pytorch 461
原文标题 :Slice pytorch tensor using coordinates tensor without loop
我有一个尺寸为(d1 x d2 x d3 x ... dk
)的张量T
和一个尺寸为(p x q
)的张量I
。这里,I
包含T
但q < k
的坐标,I
的每一列对应T
的一个维度。我有另一个张量V
维度p x di x ...dj
wheresum([di, ..., dj]) = k - q
。 (di, .., dj
) 对应于I
的缺失维度。我要表演T[I] = V
使用numpy
array张贴here[1]的此类问题的具体示例。
解决方案[2]使用了依赖于numpy.index_exp
的精美索引[3]。在pytorch
的情况下,该选项不可用。有没有其他方法可以在不使用循环或将张量转换为numpy
array 的情况下模拟这个 inpytorch
?
下面是一个演示:
import torch
t = torch.randn((32, 16, 60, 64)) # tensor
i0 = torch.randint(0, 32, (10, 1)).to(dtype=torch.long) # indexes for dim=0
i2 = torch.randint(0, 60, (10, 1)).to(dtype=torch.long) # indexes for dim=2
i = torch.cat((i0, i2), 1) # indexes
v = torch.randn((10, 16, 64)) # to be assigned
# t[i0, :, i2, :] = v ?? Obviously this does not work
[1]使用坐标列表切片numpy数组
[2]https://stackoverflow.com/a/42538465/6422069
[3]https://numpy.org/doc/stable/reference/generated/numpy.s_.html
回复
我来回复-
aretor 评论
该回答已被采纳!
经过评论中的一些讨论,我们得出了以下解决方案:
import torch t = torch.randn((32, 16, 60, 64)) # tensor # indices i0 = torch.randint(0, 32, (10,)).to(dtype=torch.long) # indexes for dim=0 i2 = torch.randint(0, 60, (10,)).to(dtype=torch.long) # indexes for dim=2 v = torch.randn((10, 16, 64)) # to be assigned t[(i0, slice(None), i2, slice(None))] = v
2年前