python - 类型错误 : 'int' object is not callable in loss. 向后()

标签 python pytorch

在尝试设置 pytorch 模型时,我收到错误消息,表示在尝试执行 Pytorch autograd 时损失对象不可调用。 (相关代码如下所示)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, 
  betas(0.0,0.9))

def train(epoch, shuffle, wisdom_model, optim, loss):
    print('train')
    accuracy = 0
    batch_num = 0
    wisdom_model.train()
    for batch in data.train_dl:

        optim.zero_grad()

        result = model(batch[0])
        loss = nn.CrossEntropyLoss()(result, batch[1].long())

        loss.backward()

        accuracy += accuracy(result, batch[1])
        print(accuracy)
        pdb.set_trace()
        batch_num += 1

    return accuracy / batch_num
TypeError                                 Traceback (most recent call last)
<ipython-input-28-5b9c9fe3b320> in <module>
----> 1 run(1, False)

<ipython-input-27-d0d67dbf6eb2> in run(num_models, dropout)
     71 
     72     for epoch in range(10):
---> 73         train_accuracy = train(epoch, False, model, optimizer, loss)
     74         accuracy.append(validate(epoch, model))
     75 

<ipython-input-27-d0d67dbf6eb2> in train(epoch, shuffle, model, optim, loss)
     24         pdb.set_trace()
     25 
---> 26         loss.backward()
     27         optim.step()
     28 

TypeError: 'int' object is not callable

最佳答案

问题出在这一行:

loss = nn.CrossEntropyLoss()(result, batch[1].long())

查看 nn.CrossEntropyLoss .

不应该是这样的:

nn.CrossEntropyLoss()()

应该看起来像这样:

nn.CrossEntropyLoss()

关于python - 类型错误 : 'int' object is not callable in loss. 向后(),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57838192/

相关文章:

python - 如何在python中以这种格式输出文件名及其单词内容?

python - 我如何使用 pyclustering 来实现 kmedoids?

python - mac 上的多处理导入模块

machine-learning - 在 Pytorch 中下载预训练的 GAN 模型时出错 : 'memory' file not found

python - 如何阻止 Python Sci-kit 库的 Count Vectorizer 进行任何类型的单词过滤?

python - 如何在 Python 3 中根据字符串参数捕获一般异常?

machine-learning - 神经网络的 4d 输入张量与 1d 输入张量(又名向量)

python - 您如何有效地将一个数组中某个值在另一个数组中的位置出现的次数相加

python - 如何修复 google colab 上的 cuda 运行时错误?

python - 如何修复数据集以返回所需的输出(pytorch)