pytorch - torch.where 中的标量类型?

标签 pytorch

torch.where documentation指出 x 和 y 可以是张量或标量。但是,它似乎不支持 float32 标量。

import torch

x = torch.randn(3, 2)  # x is of type torch.float32

torch.where(x>0, 0, x)  # RuntimeError: expected scalar type long long but found float 
# torch.where(x>0, 0.0, x)  # RuntimeError: expected scalar type double but found float

我的问题是如何使用 float32 标量?

最佳答案

并不是说torch不支持float32。这是您的系统没有提供将 0 指定为 float32 的简单方法。如错误中所述,0 被解释为 long long C 类型,即 int64,而 0.0 是解释为 double C 类型,即 float64

我想您需要将 0 转换为与 x 相同的 dtype:

torch.where(x>0.0, torch.tensor(0, dtype=x.dtype), x)

关于pytorch - torch.where 中的标量类型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66299149/

相关文章:

python - Pytorch CNN 的损失没有减少

python - 将张量分配给多个切片

python - ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index) 抛出 IndexError : Target 42 is out of bounds

python - pytorch 中的 reshape 和 view 有什么区别?

python - 仅增强 K 折交叉验证中的训练集

python - PyTorch——如何正确使用 "toPILImage"

python - pytorch中引入nn.Parameter的目的

python - 在 virtualenv 中使用 python3.5 导入 torch 时出现段错误(核心转储)

python - 如何在 PyTorch 中平衡(过采样)不平衡数据(使用 WeightedRandomSampler)?

python-3.x - 用户警告 : Implicit dimension choice for log_softmax has been deprecated