Pytorch dist.all_gather_object 挂起

青葱年少 pytorch 696

原文标题Pytorch dist.all_gather_object hangs

我正在使用 dist.all_gather_object (PyTorch 1.8 版)从所有 GPU 收集样本 id:

for batch in dataloader:
    video_sns = batch["video_ids"]
    logits = model(batch)
    group_gather_vdnames = [None for _ in range(envs['nGPU'])]
    group_gather_logits = [torch.zeros_like(logits) for _ in range(envs['nGPU'])]
    dist.all_gather(group_gather_logits, logits)
    dist.all_gather_object(group_gather_vdnames, video_sns)

dist.all_gather(group_gather_logits, logits)行正常运行,但程序在dist.all_gather_object(group_gather_vdnames, video_sns)行挂起。

我想知道为什么程序挂在dist.all_gather_object(),我该如何解决?

额外信息:我在具有多个 GPU 的本地计算机上运行我的 ddp 代码。启动脚本是:

export NUM_NODES=1
export NUM_GPUS_PER_NODE=2
export NODE_RANK=0
export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))

python -m torch.distributed.launch \
       --nproc_per_node=$NUM_GPUS_PER_NODE \
       --nnodes=$NUM_NODES \
       --node_rank $NODE_RANK \
       main.py \
       --my_args

原文链接:https://stackoverflow.com//questions/71568524/pytorch-dist-all-gather-object-hangs

回复

我来回复
  • Zhang Yu的头像
    Zhang Yu 评论

    事实证明,我们需要手动设置设备 ID,如dist.all_gather_object()API 的文档字符串中所述。

    添加

    torch.cuda.set_device(envs['LRANK']) # my local gpu_id
    

    并且代码有效。

    我一直认为 GPU ID 是由 PyTorch dist 自动设置的,结果并非如此。

    2年前 0条评论