python - 为什么我的简单 pytorch 网络不能在 GPU 设备上运行?

标签 python image-processing machine-learning deep-learning pytorch

我根据教程构建了一个简单的网络,但遇到了这个错误:

RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #4 'mat1'

有什么帮助吗?谢谢!

import torch
import torchvision

device = torch.device("cuda:0")
root = '.data/'

dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.out = torch.nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

net = Net()
net.to(device)

for i, (inputs, labels) in enumerate(dataloader):
    inputs.to(device)
    out = net(inputs)

最佳答案

长话短说
这是修复

inputs = inputs.to(device)  

为什么?!
torch.nn.Module.to()之间略有不同和 torch.Tensor.to() :虽然 Module.to() 是一个in-place 运算符,但 Tensor.to() 不是。因此

net.to(device)

更改 net 本身并将其移动到 device。另一方面

inputs.to(device)

不会更改inputs,而是返回驻留在device 上的inputs副本。要使用该“在设备上”的副本,您需要将其分配给一个变量,因此

inputs = inputs.to(device)

关于python - 为什么我的简单 pytorch 网络不能在 GPU 设备上运行?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51605893/

相关文章:

python - 从 pandas 数据框转换为 LabeledPoint RDD

c++ - 没有 if-else 语句的一维卷积(非 FFT)?

java - 如何从复杂的研究论文着手编写算法

c++ - 识别开放和封闭的形状opencv

python - 如何保存和重新训练 neat-python 模型?

python - 如何在 Python Flask 中读取文件

python - 如何在 django Rest 框架中使用外键进行计数

machine-learning - 预测 cucumber 收获

r - 为什么我们在 R 中的 model.matrix 函数中提到 -1 ?是为了一种热编码还是有其他原因?

r - 从决策树进行预测的高效算法(使用 R)