我根据教程构建了一个简单的网络,但遇到了这个错误:
RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #4 'mat1'
有什么帮助吗?谢谢!
import torch
import torchvision
device = torch.device("cuda:0")
root = '.data/'
dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.out = torch.nn.Linear(28*28, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.out(x)
return x
net = Net()
net.to(device)
for i, (inputs, labels) in enumerate(dataloader):
inputs.to(device)
out = net(inputs)
最佳答案
长话短说
这是修复
inputs = inputs.to(device)
为什么?!
torch.nn.Module.to()
之间略有不同和 torch.Tensor.to()
:虽然 Module.to()
是一个in-place 运算符,但 Tensor.to()
不是。因此
net.to(device)
更改 net
本身并将其移动到 device
。另一方面
inputs.to(device)
不会更改inputs
,而是返回驻留在device
上的inputs
的副本。要使用该“在设备上”的副本,您需要将其分配给一个变量,因此
inputs = inputs.to(device)
关于python - 为什么我的简单 pytorch 网络不能在 GPU 设备上运行?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51605893/