Pytorch dist.all_gather_object hangs
Original title: Pytorch dist.all_gather_object hangs
I'm using dist.all_gather_object (PyTorch version 1.8) to collect sample ids from all GPUs:
for batch in dataloader:
    video_sns = batch["video_ids"]
    logits = model(batch)
    group_gather_vdnames = [None for _ in range(envs['nGPU'])]
    group_gather_logits = [torch.zeros_like(logits) for _ in range(envs['nGPU'])]
    dist.all_gather(group_gather_logits, logits)
    dist.all_gather_object(group_gather_vdnames, video_sns)
The dist.all_gather(group_gather_logits, logits) line runs fine, but the program hangs on the dist.all_gather_object(group_gather_vdnames, video_sns) line.
I'd like to know why the program hangs at dist.all_gather_object(), and how can I fix it?
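For context: the PyTorch docs note that with the NCCL backend, dist.all_gather_object pickles the object into a tensor placed on the current CUDA device (torch.cuda.current_device()), so if no rank calls torch.cuda.set_device, every process serializes onto cuda:0 and the collective can deadlock. A minimal, self-contained sketch of the pattern (the placeholder sample ids are mine, and it falls back to gloo on CPU so it stays runnable):

```python
import os
import torch
import torch.distributed as dist

def run(local_rank: int, world_size: int):
    # With NCCL, all_gather_object moves its pickled payload to the
    # *current* CUDA device. If every rank still defaults to cuda:0,
    # ranks wait on mismatched devices and the collective hangs, so
    # pin the device before initializing / calling any collective.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl" if use_cuda else "gloo",
        rank=local_rank,
        world_size=world_size,
    )

    video_sns = [f"video_{dist.get_rank()}"]  # placeholder sample ids
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, video_sns)
    dist.destroy_process_group()
    return gathered

if __name__ == "__main__":
    # Rendezvous settings normally provided by the launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    print(run(int(os.environ.get("LOCAL_RANK", "0")),
              int(os.environ.get("WORLD_SIZE", "1"))))
```

Calling torch.cuda.set_device(local_rank) early in the script (before the first collective) is the usual fix for this symptom.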
Extra info: I'm running my DDP code on a single local machine with multiple GPUs. The launch script is:
export NUM_NODES=1
export NUM_GPUS_PER_NODE=2
export NODE_RANK=0
export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))
python -m torch.distributed.launch \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank $NODE_RANK \
    main.py \
    --my_args
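For reference, torch.distributed.launch (as of PyTorch 1.8, when run without --use_env) starts main.py once per GPU and appends a --local_rank=<n> argument to each worker's command line; the script is expected to parse that flag and use it to pin its process to one GPU. A sketch of that entry-point boilerplate (the helper name is my own):

```python
import argparse
import torch
import torch.distributed as dist

def parse_local_rank(argv=None):
    # torch.distributed.launch injects --local_rank=<n> into argv;
    # parse_known_args lets the script's own flags pass through.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args, _ = parser.parse_known_args(argv)
    return args.local_rank

if __name__ == "__main__":
    local_rank = parse_local_rank()
    torch.cuda.set_device(local_rank)        # pin this process to its GPU
    dist.init_process_group(backend="nccl")  # launcher already set the env vars
```

If the launcher is run with --use_env instead, the rank arrives in the LOCAL_RANK environment variable rather than as a command-line flag.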