将不同长度的 2D 张量列表转换为一个 3D 张量
pytorch 330
原文标题 :Turning list of 2D tensors with different length to one 3D tensor
我有一个包含 3 个形状的张量列表:(8, 2), (8, 4), (8, 6)
我想把这个列表变成这个形状: (8, 3, x)
我该怎么做呢?我知道我需要使用 torch.cat 、 torch.stack 和 torch.transpose 的某种组合,但我想不通。
提前致谢!
回复
我来回复-
Tomer Geva 评论
如您所说,您需要使用 torch.cat ,还需要使用 torch.reshape 。假设如下:
a = torch.rand(8,2) b = torch.rand(8,4) c = torch.rand(8,6)
并假设确实可以将张量重塑为 (8,3,-1) 形状,其中 -1 表示只要需要,则:
d = torch.cat((a,b,c), dim=1) e = torch.reshape(d, (8,3,-1))
我会解释的。因为第一个维度如果 a,b,c 不同,则串联必须沿着第一个维度,如变量 d 所示。然后,您可以重塑张量,如 e 所示,其中 -1 代表“只要需要”。
2年前