Pytorch Python 分布式多处理:收集/连接不同长度/大小的张量数组
pytorch 728
原文标题 :Pytorch Python Distributed Multiprocessing: Gather/Concatenate tensor arrays of different lengths/sizes
如果您在多个 gpu 等级中有不同长度的张量数组,则默认的 all_gather 方法不起作用,因为它要求长度相同。
例如,如果您有:
if gpu == 0:
q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
q = torch.tensor([5.3], device=torch.device(gpu))
如果我需要按如下方式收集这两个张量数组:
all_q = [torch.tensor([1.5, 2.3], torch.tensor[5.3])
默认的 torch.all_gather 不起作用,因为长度 2、1 不同。
回复
我来回复-
omsrisagar 评论
由于无法直接使用内置方法进行收集,我们需要编写自定义函数,步骤如下:
- 使用 dist.all_gather 获取所有数组的大小。
- 找到最大尺寸。
- 使用零/常量将本地数组填充到最大大小。
- 使用 dist.all_gather 获取所有填充数组。
- 使用步骤 1 中找到的大小取消填充添加的零/常量。
以下函数执行此操作:
def all_gather(q, ws, device): """ Gathers tensor arrays of different lengths across multiple gpus Parameters ---------- q : tensor array ws : world size device : current gpu device Returns ------- all_q : list of gathered tensor arrays from all the gpus """ local_size = torch.tensor(q.size(), device=device) all_sizes = [torch.zeros_like(local_size) for _ in range(ws)] dist.all_gather(all_sizes, local_size) max_size = max(all_sizes) size_diff = max_size.item() - local_size.item() if size_diff: padding = torch.zeros(size_diff, device=device, dtype=q.dtype) q = torch.cat((q, padding)) all_qs_padded = [torch.zeros_like(q) for _ in range(ws)] dist.all_gather(all_qs_padded, q) all_qs = [] for q, size in zip(all_qs_padded, all_sizes): all_qs.append(q[:size]) return all_qs
一旦我们能够完成上述操作,我们就可以在需要时轻松使用 torch.cat 进一步连接成单个数组:
torch.cat(all_q) [torch.tensor([1.5, 2.3, 5.3])
改编自:github
2年前