我在训练 BERT 分类器时遇到了以下错误。
type(b_input_mask) = type(b_labels) = torch.Tensor
type(b_labels[i]) = tensor(1., dtype=torch.float64)
type(b_input_masks[i]) = class'torch.Tensor'
由于我没有将任何变量类型转换为 long 或 double,这里可能存在什么数据类型错误?
最佳答案
在分类任务中,输入标签的数据类型应该是 Long,但您将它们指定为 float64
type(b_labels[i]) = tensor(1., dtype=torch.float64)
=>
type(b_labels[i]) = tensor(1., dtype=torch.long)
关于python - torch 错误,运行时错误: expected scalar type Long but found Double,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62400112/