python - torch 错误,运行时错误: expected scalar type Long but found Double

标签 python types pytorch

我在训练 BERT 分类器时遇到了以下错误。

type(b_input_mask) = type(b_labels) = torch.Tensor      

type(b_labels[i]) = tensor(1., dtype=torch.float64)

type(b_input_masks[i]) = class'torch.Tensor'

由于我没有将任何变量类型转换为 long 或 double,这里可能存在什么数据类型错误?

提前致谢! Error Stack Trace

最佳答案

在分类任务中,输入标签的数据类型应该是 Long,但您将它们指定为 float64

type(b_labels[i]) = tensor(1., dtype=torch.float64)

=>

type(b_labels[i]) = tensor(1., dtype=torch.long)

关于python - torch 错误,运行时错误: expected scalar type Long but found Double,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62400112/

相关文章:

python - 将不同大小的 torch 张量设置为相等

pytorch - 如何使用 PyTorch 保存和加载神经网络的特定层?

python - 在损失函数中使用 autograd 时 PyTorch 不更新权重

python - Pandas 通过其他列的数量来填充列

python - 在模板中显示自定义 Django 表单错误

按名称调用和按值调用的 Scala 父类(super class)型

c++ - 从通过 decltype 获得的方法类型中剥离类

python - 在 Python Click 库中使用 bool 标志(命令行参数)

python - Tensorflow:具有非负约束的线性回归

C11 标准保证将一种数组类型转换为另一种数组类型