如何每行单独切片 2D Torch 张量?

扎眼的阳光 pytorch 244

原文标题How to slice 2D Torch tensor individually per row?

我在 Pytorch 中有一个想要切片的 2D 张量:

x = torch.rand((3, 5))

在这个例子中,张量有 3 行,我想对 x 进行切片,创建一个新的张量 y,它也有 3 行和 num_col 列。

对我来说具有挑战性的是我想每行切片不同的列。我所拥有的只是 x 、 num_cols 和 idx ,这是一个张量,它保存了从何处切片的起始索引。

示例:我拥有的是 num_cols=2 , idx=[1,2,3] 和

x=torch.arange(15).reshape((3,-1)) =
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])

我想要的是

y=
tensor([[ 1,  2],
        [ 7,  8],
        [13, 14]])

这样做的“火炬”方式是什么?我知道,如果我以某种方式获得布尔掩码,我可以切片,但我不知道如何在没有正常 Python 循环的情况下使用 idx 和 num_cols 构造它。

原文链接:https://stackoverflow.com//questions/71425677/how-to-slice-2d-torch-tensor-individually-per-row

回复

我来回复
  • Kevin的头像
    Kevin 评论

    您可以将花式索引与广播一起使用。另一种解决方案可能是使用类似于 numpy 的 take_along_axis 的 torch.gather 。您的 idx 数组需要使用额外的列进行扩展:

    x = torch.arange(15).reshape(3,-1)
    idx = torch.tensor([1,2,3])
    
    idx = torch.column_stack([idx, idx+1])
    torch.gather(x, 1, idx)
    

    输出:

    tensor([[ 1,  2],
            [ 7,  8],
            [13, 14]])
    
    2年前 0条评论