python - Pytorch 期望类型为 Long 但得到类型为 int

标签 python python-3.x pytorch tensor

我恢复了一个错误

 Expected object of scalar type Long but got scalar type Int for argument #3 'index'

这是来自这一行。

targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)

我不知道该怎么做,因为我试图使用几个地方将其转换为 long 。我尝试放一个

.long

最后将 dtype 设置为 torch.long 仍然不起作用。

与此非常相似,但他没有做任何事情来得到答案 "Expected Long but got Int" while running PyTorch script

我更改了很多代码,这是我的最后一次演绎,但现在给了我同样的问题。

    def forward(self, inputs, targets):
            """
            Args:
                inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
                targets: ground truth labels with shape (num_classes)
            """
            log_probs = self.logsoftmax(inputs)
            targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
            if self.use_gpu: targets = targets.to(torch.device('cuda'))
            targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
            loss = (- targets * log_probs).mean(0).sum()
            return loss

最佳答案

索引参数的数据类型(即 targets.unsqueeze(1).data.cpu())需要为 torch.int64

(错误消息有点令人困惑:torch.long 不存在。但 PyTorch 内部的“Long”表示 int64)。

关于python - Pytorch 期望类型为 Long 但得到类型为 int,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56391202/

相关文章:

python - 向具有非数字 x 轴的多个子图添加一条水平线

python - 将张量列表转换为张量

python - 根据 MultiIndex DataFrame 中的第一级列删除重复项

python - 自定义排序 pandas 数据框

Python:如何拆分值并按类别查找平均值

json - 如何将每 n 行添加到 n 列?

python - pytorch中的一项热门编码文本数据

python - 我如何使用 Anaconda 卸载 PyTorch?

python - Scipy Weibull 参数置信区间

python - 通过重复字典键值拆分 python 字典列表