python - 关于torch.nn.DataParallel的问题

标签 python pytorch

我是深度学习领域的新手。现在我正在复现一篇论文的代码。由于他们使用多个 GPU,因此代码中有一个命令 torch.nn.DataParallel(model, device_ids= args.gpus).cuda()。但是我只有一个GPU,什么 我应该更改此代码以匹配我的 GPU 吗?

谢谢!

最佳答案

DataParallel 也应该在单个 GPU 上工作,但您应该检查 args.gpus 是否仅包含要使用的设备的 ID(应该是0) 或 。 选择 None 将使模块使用所有可用设备。

您也可以删除 DataParallel,因为您不需要它,仅通过调用 model.cuda() 将模型移动到 GPU,或者,如我更喜欢,model.to(device) 其中 device 是设备的名称。

示例:

此示例展示了如何在单个 GPU 上使用模型,使用 .to() 而不是 .cuda() 设置设备。

from torch import nn
import torch

# Set device to cuda if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create model
model = nn.Sequential(
  nn.Conv2d(1,20,5),
  nn.ReLU(),
  nn.Conv2d(20,64,5),
  nn.ReLU()
)

# moving model to GPU
model.to(device)

如果你想使用DataParallel,你可以这样做

# Optional DataParallel, not needed for single GPU usage
model1 = torch.nn.DataParallel(model, device_ids=[0]).to(device)
# Or, using default 'device_ids=None'
model1 = torch.nn.DataParallel(model).to(device)

关于python - 关于torch.nn.DataParallel的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52663358/

相关文章:

python - 如何在 Windows 上修复 python 的 getaddrinfo 失败

python - PyTorch 函数中的下划线后缀是什么意思?

python - 将 numpy 数组设置为切片而无需任何就地操作

python - 如何缓存 Pytorch 模型以供未连接互联网时使用?

pytorch - 无法找到有效的 cuDNN 算法来运行卷积

numpy - Torch 广播如何为 (8, 8) @ (4, 8, 2) 工作?

python - Tornado 对一个请求推送多个响应

python - "ProgrammingError: column "genre_id "of relation "music_album "does not exist"而该列确实存在

python - 模块未找到错误 : No module named 'SessionState

Python:以编程方式将 Python 包编译为 pyc 或 pyo 文件