在随机批次中优化一个维度上的火炬平均值
pytorch 418
原文标题 :Optimizing torch mean over a dimension in a random batch
我正在寻找一种方法来优化 pytorch 中的以下代码。
我有一个函数f
在空间x,y
和时间t
上定义。
在随机批次中,我需要计算所有相同时间戳的平均值。我能够通过以下低效的 for 循环来实现这一点
import torch
# Space (x,y) and time (t) coordinates in a random batch
x = torch.Tensor([[0, 0, 1, 0],[3, 2, 2, 1],[1,3,5,5]]).T
# compute a dummy function u = f(t,x,y)
f = (x**2 + 0.5)[:,:2]
# timestamps
t = x[:,0]
# get unique timestamps
val = torch.unique(t.squeeze())
for v in val:
# compute a mask for all timestamp equal to v
mask = t == v
# average over the spatial coordinates
f[mask,:] = torch.mean(f[mask,:], dim=0)
print(f)
这导致
f = tensor([[0.5000, 5.1667],
[0.5000, 5.1667],
[1.5000, 4.5000],
[0.5000, 5.1667]])
有没有办法让这个计算更快?
回复
我来回复-
Shai 评论
我想你在找
index_add_
:avg_size = int(t.max().item()) + 1 # number of rows in output tensor z = torch.zeros((avg_size, f.shape[1]), dtype=f.dtype) s = torch.index_add(s, 0, t.long(), f) # sum the elements of f c = torch.index_add(s, 0, t.long(), torch.ones_like(f[:, :1])) # count how many at each entry out = s / c # divide to get the mean
2年前