使用没有循环的坐标张量切片pytorch张量

社会演员多 pytorch 461

原文标题Slice pytorch tensor using coordinates tensor without loop

我有一个尺寸为(d1 x d2 x d3 x ... dk)的张量T和一个尺寸为(p x q)的张量I。这里,I包含Tq < k的坐标,I的每一列对应T的一个维度。我有另一个张量V维度p x di x ...djwheresum([di, ..., dj]) = k - q。 (di, .., dj) 对应于I的缺失维度。我要表演T[I] = V

使用numpyarray张贴here[1]的此类问题的具体示例。

解决方案[2]使用了依赖于numpy.index_exp的精美索引[3]。在pytorch的情况下,该选项不可用。有没有其他方法可以在不使用循环或将张量转换为numpyarray 的情况下模拟这个 inpytorch

下面是一个演示:

import torch
t = torch.randn((32, 16, 60, 64)) # tensor

i0 = torch.randint(0, 32, (10, 1)).to(dtype=torch.long) # indexes for dim=0
i2 = torch.randint(0, 60, (10, 1)).to(dtype=torch.long) # indexes for dim=2

i = torch.cat((i0, i2), 1) # indexes
v = torch.randn((10, 16, 64)) # to be assigned

# t[i0, :, i2, :] = v ?? Obviously this does not work

[1]使用坐标列表切片numpy数组

[2]https://stackoverflow.com/a/42538465/6422069

[3]https://numpy.org/doc/stable/reference/generated/numpy.s_.html

原文链接:https://stackoverflow.com//questions/71479364/slice-pytorch-tensor-using-coordinates-tensor-without-loop

回复

我来回复
  • aretor的头像
    aretor 评论

    经过评论中的一些讨论,我们得出了以下解决方案:

    import torch
    t = torch.randn((32, 16, 60, 64)) # tensor
    
    # indices
    i0 = torch.randint(0, 32, (10,)).to(dtype=torch.long) # indexes for dim=0
    i2 = torch.randint(0, 60, (10,)).to(dtype=torch.long) # indexes for dim=2
    
    v = torch.randn((10, 16, 64)) # to be assigned
    
    t[(i0, slice(None), i2, slice(None))] = v
    
    2年前 0条评论