Pytorch Python 分布式多处理:收集/连接不同长度/大小的张量数组

社会演员多 pytorch 728

原文标题Pytorch Python Distributed Multiprocessing: Gather/Concatenate tensor arrays of different lengths/sizes

如果您在多个 gpu 等级中有不同长度的张量数组,则默认的 all_gather 方法不起作用,因为它要求长度相同。

例如,如果您有:

if gpu == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
    q = torch.tensor([5.3], device=torch.device(gpu))

如果我需要按如下方式收集这两个张量数组:

all_q = [torch.tensor([1.5, 2.3], torch.tensor[5.3])

默认的 torch.all_gather 不起作用,因为长度 2、1 不同。

原文链接:https://stackoverflow.com//questions/71433507/pytorch-python-distributed-multiprocessing-gather-concatenate-tensor-arrays-of

回复

我来回复
  • omsrisagar的头像
    omsrisagar 评论

    由于无法直接使用内置方法进行收集,我们需要编写自定义函数,步骤如下:

    1. 使用 dist.all_gather 获取所有数组的大小。
    2. 找到最大尺寸。
    3. 使用零/常量将本地数组填充到最大大小。
    4. 使用 dist.all_gather 获取所有填充数组。
    5. 使用步骤 1 中找到的大小取消填充添加的零/常量。

    以下函数执行此操作:

    def all_gather(q, ws, device):
        """
        Gathers tensor arrays of different lengths across multiple gpus
        
        Parameters
        ----------
            q : tensor array
            ws : world size
            device : current gpu device
            
        Returns
        -------
            all_q : list of gathered tensor arrays from all the gpus
    
        """
        local_size = torch.tensor(q.size(), device=device)
        all_sizes = [torch.zeros_like(local_size) for _ in range(ws)]
        dist.all_gather(all_sizes, local_size)
        max_size = max(all_sizes)
    
        size_diff = max_size.item() - local_size.item()
        if size_diff:
            padding = torch.zeros(size_diff, device=device, dtype=q.dtype)
            q = torch.cat((q, padding))
    
        all_qs_padded = [torch.zeros_like(q) for _ in range(ws)]
        dist.all_gather(all_qs_padded, q)
        all_qs = []
        for q, size in zip(all_qs_padded, all_sizes):
            all_qs.append(q[:size])
        return all_qs
    

    一旦我们能够完成上述操作,我们就可以在需要时轻松使用 torch.cat 进一步连接成单个数组:

    torch.cat(all_q)
    [torch.tensor([1.5, 2.3, 5.3])
    

    改编自:github

    2年前 0条评论