torch.where documentation指出 x 和 y 可以是张量或标量。但是,它似乎不支持 float32 标量。
import torch
x = torch.randn(3, 2) # x is of type torch.float32
torch.where(x>0, 0, x) # RuntimeError: expected scalar type long long but found float
# torch.where(x>0, 0.0, x) # RuntimeError: expected scalar type double but found float
我的问题是如何使用 float32 标量?
最佳答案
并不是说torch
不支持float32
。这是您的系统没有提供将 0
指定为 float32
的简单方法。如错误中所述,0
被解释为 long long
C 类型,即 int64
,而 0.0
是解释为 double
C 类型,即 float64
。
我想您需要将 0
转换为与 x
相同的 dtype
:
torch.where(x>0.0, torch.tensor(0, dtype=x.dtype), x)
关于pytorch - torch.where 中的标量类型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66299149/