将不同长度的 2D 张量列表转换为一个 3D 张量

乘风 pytorch 330

原文标题Turning list of 2D tensors with different length to one 3D tensor

我有一个包含 3 个形状的张量列表:(8, 2), (8, 4), (8, 6)

我想把这个列表变成这个形状: (8, 3, x)

我该怎么做呢?我知道我需要使用 torch.cat 、 torch.stack 和 torch.transpose 的某种组合,但我想不通。

提前致谢!

原文链接:https://stackoverflow.com//questions/71456457/turning-list-of-2d-tensors-with-different-length-to-one-3d-tensor

回复

我来回复
  • Tomer Geva的头像
    Tomer Geva 评论

    如您所说,您需要使用 torch.cat ,还需要使用 torch.reshape 。假设如下:

    a = torch.rand(8,2)
    b = torch.rand(8,4)
    c = torch.rand(8,6)
    

    并假设确实可以将张量重塑为 (8,3,-1) 形状,其中 -1 表示只要需要,则:

    d = torch.cat((a,b,c), dim=1)
    e = torch.reshape(d, (8,3,-1))
    

    我会解释的。因为第一个维度如果 a,b,c 不同,则串联必须沿着第一个维度,如变量 d 所示。然后,您可以重塑张量,如 e 所示,其中 -1 代表“只要需要”。

    2年前 0条评论