PyTorch – 如何使用形状 [160] 的数组来索引形状 [2, 160, 12, 1024] 的数组以获取 [2, 160, 1024]?
pytorch 253
原文标题 :PyTorch – How to index an array of shape [2, 160, 12, 1024] using an array of shape [160] to get [2, 160, 1024]?
我有一个形状为 [2, 160, 12, 1024] 的张量A
。我有另一个形状为 [160] 的张量B
,其值在 0 到 11 之间,例如。[0,11,3,2,1,2, 6,7,…]
我正在尝试使用B
来索引A
沿着 THIRDaxis 它使得输出为 [2, 160, 1024] 例如
C = np.zeros((2,160,1024))
for i in range(160):
C[:,i,:] = A[:,i,B[i],:]
A[:,B]
不起作用,因为它最终会沿第二个轴进行索引。使用 for 循环的解决方案很慢,因为我有一个明显更大的数组。