python - 运行时错误 : Expected object of scalar type Long but got scalar type Float for argument #2 'mat2' how to fix it?

标签 python neural-network pytorch typing


import torch.nn as nn 
import torch 
import torch.optim as optim
import itertools

class net1(nn.Module):
    def __init__(self):
        super(net1,self).__init__()

        self.pipe = nn.Sequential(
            nn.Linear(10,10),
            nn.ReLU()
        )

    def forward(self,x):
        return self.pipe(x.long())

class net2(nn.Module):
    def __init__(self):
        super(net2,self).__init__()

        self.pipe = nn.Sequential(
            nn.Linear(10,20),
            nn.ReLU(),
            nn.Linear(20,10)
        )

    def forward(self,x):
        return self.pipe(x.long())



netFIRST = net1()
netSECOND = net2()

learning_rate = 0.001

opt = optim.Adam(itertools.chain(netFIRST.parameters(),netSECOND.parameters()), lr=learning_rate)

epochs = 1000

x = torch.tensor([1,2,3,4,5,6,7,8,9,10],dtype=torch.long)
y = torch.tensor([10,9,8,7,6,5,4,3,2,1],dtype=torch.long)


for epoch in range(epochs):
    opt.zero_grad()

    prediction = netSECOND(netFIRST(x))
    loss = (y.long() - prediction)**2
    loss.backward()

    print(loss)
    print(prediction)
    opt.step()

错误:

line 49, in prediction = netSECOND(netFIRST(x))

line 1371, in linear; output = input.matmul(weight.t())

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2'

我真的不明白我做错了什么。我试图以各种可能的方式将所有内容都变成 Long。我真的不明白 pytorch 的打字方式。上次我只尝试了一层,它迫使我使用 int 类型。 有人可以解释一下在 pytorch 中如何建立类型以及如何防止和修复这样的错误吗? 很多,我的意思是提前非常感谢,这个问题真的很困扰我,无论我尝试什么,我似乎都无法解决它。

最佳答案

权重是 Floats,输入是 Longs。这是不允许的。事实上,我不认为 torch 支持神经网络中的 float 。

如果您删除所有 对 long 的调用,并将您的输入定义为 float ,它将起作用(它确实有效,我试过了)。

(然后你会得到另一个不相关的错误:你需要总结你的损失)

关于python - 运行时错误 : Expected object of scalar type Long but got scalar type Float for argument #2 'mat2' how to fix it?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58157523/

相关文章:

python - 使用 cairo 在 Python 中绘制大量圆圈

神经网络的 Python 实时图像分类问题

python - 在 Python Tkinter 中,如何在主循环窗口打开后运行代码?

python - 了解 CNN 超参数

python - 我的分类器损失很大,准确率始终为 0

python - 断言错误 : 200 ! = 403

python - val_loss 减半,但 val_acc 保持不变

python - 将张量列表转换为张量的张量 pytorch

python - 导入 PyTorch 时出错 - Python

python - 迭代一段时间后,前向传递速度变慢 10000 倍