我正在使用 dist.all_gather_object(PyTorch 版本 1.8)从所有 GPU 收集样本 ID:
for batch in dataloader:
video_sns = batch["video_ids"]
logits = model(batch)
group_gather_vdnames = [None for _ in range(envs['nGPU'])]
group_gather_logits = [torch.zeros_like(logits) for _ in range(envs['nGPU'])]
dist.all_gather(group_gather_logits, logits)
dist.all_gather_object(group_gather_vdnames, video_sns)
行 dist.all_gather(group_gather_logits, logits)
工作正常,
但程序卡在 dist.all_gather_object(group_gather_vdnames, video_sns)
行。
我想知道为什么程序在 dist.all_gather_object()
挂起,我该如何修复它?
额外信息: 我在具有多个 GPU 的本地计算机上运行 ddp 代码。启动脚本是:
export NUM_NODES=1
export NUM_GPUS_PER_NODE=2
export NODE_RANK=0
export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))
python -m torch.distributed.launch \
--nproc_per_node=$NUM_GPUS_PER_NODE \
--nnodes=$NUM_NODES \
--node_rank $NODE_RANK \
main.py \
--my_args
最佳答案
事实证明,我们需要手动设置设备 ID,如 docstring 中所述。 dist.all_gather_object()
API。
添加
torch.cuda.set_device(envs['LRANK']) # my local gpu_id
并且代码可以工作。
我一直以为 GPU ID 是由 PyTorch dist 自动设置的,事实证明并非如此。
关于deep-learning - Pytorch dist.all_gather_object 挂起,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71568524/