lua - torch 7 : Filtering out NaN values

标签 lua nan torch

给定任何通用的float torch.Tensor,可能包含一些NaN值,我正在寻找一种有效的方法来将其中的所有NaN值替换为零,或者完全删除它们并过滤掉另一个新张量中的“有用”值。

我知道执行此操作的一个简单方法是手动迭代给定张量中的所有值(并相应地将它们替换为零或拒绝新张量)。

是否有一些预定义的 Torch 功能或功能组合可以在性能方面更有效地实现这一目标,这依赖于 Torch 固有的 CPU-GPU 优化?

最佳答案

嗯,torch 中似乎没有检查张量是否为 NaN 的函数。但由于 NaN != NaN,有一个解决方法:

a = torch.rand(4, 5)
a[2][3] = tonumber('nan')
nan_mask = a:ne(a)
notnan_mask = a:eq(a)

print(a)
 0.2434  0.1731  0.3440  0.3340  0.0519
 0.0932  0.4067  nan     0.1827  0.5945
 0.3020  0.1035  0.5415  0.3329  0.7881
 0.6108  0.9498  0.0406  0.9335  0.3582
[torch.DoubleTensor of size 4x5]

print(nan_mask)
 0  0  0  0  0
 0  0  1  0  0
 0  0  0  0  0
 0  0  0  0  0
[torch.ByteTensor of size 4x5]

有了这些掩码,您可以有效地提取 NaN/非 NaN 值并将其替换为您想要的任何值:

print(a[notnan_mask])
...
[torch.DoubleTensor of size 19]

a[nan_mask] = 42
print(a)
  0.2434   0.1731   0.3440   0.3340   0.0519
  0.0932   0.4067  42.0000   0.1827   0.5945
  0.3020   0.1035   0.5415   0.3329   0.7881
  0.6108   0.9498   0.0406   0.9335   0.3582
[torch.DoubleTensor of size 4x5]

关于lua - torch 7 : Filtering out NaN values,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37144125/

相关文章:

lua - 在 Lua 中为变量添加值?

c++ - 从Pytorch C++中的c10::Dict <c10::IValue,c10::IValue>获取值

lua - 如何拆分包含子表的 Lua 表

security - Lua - 我如何在不改变 GetFenv 函数的情况下欺骗它?

floating-point - Fortran 中具有 NaN 值的参数(常量)变量

cuda - 在 CUDA 中获取浮点特殊值的方法?

javascript - 自增运算符返回 NaN

docker - 使用 CUDA 在 docker 上运行 torch 表示未找到模块 'cutorch'

python - 如何将 PyTorch 张量的每一行中的重复值清零?

lua - 从 lua 中的一个输入表输出两个表,第一个输出表将包含所有键,第二个将包含输入表的所有值