给定任何通用的float torch.Tensor
,可能包含一些NaN值,我正在寻找一种有效的方法来将其中的所有NaN值替换为零,或者完全删除它们并过滤掉另一个新张量中的“有用”值。
我知道执行此操作的一个简单方法是手动迭代给定张量中的所有值(并相应地将它们替换为零或拒绝新张量)。
是否有一些预定义的 Torch 功能或功能组合可以在性能方面更有效地实现这一目标,这依赖于 Torch 固有的 CPU-GPU 优化?
最佳答案
嗯,torch 中似乎没有检查张量是否为 NaN 的函数。但由于 NaN != NaN,有一个解决方法:
a = torch.rand(4, 5)
a[2][3] = tonumber('nan')
nan_mask = a:ne(a)
notnan_mask = a:eq(a)
print(a)
0.2434 0.1731 0.3440 0.3340 0.0519
0.0932 0.4067 nan 0.1827 0.5945
0.3020 0.1035 0.5415 0.3329 0.7881
0.6108 0.9498 0.0406 0.9335 0.3582
[torch.DoubleTensor of size 4x5]
print(nan_mask)
0 0 0 0 0
0 0 1 0 0
0 0 0 0 0
0 0 0 0 0
[torch.ByteTensor of size 4x5]
有了这些掩码,您可以有效地提取 NaN/非 NaN 值并将其替换为您想要的任何值:
print(a[notnan_mask])
...
[torch.DoubleTensor of size 19]
a[nan_mask] = 42
print(a)
0.2434 0.1731 0.3440 0.3340 0.0519
0.0932 0.4067 42.0000 0.1827 0.5945
0.3020 0.1035 0.5415 0.3329 0.7881
0.6108 0.9498 0.0406 0.9335 0.3582
[torch.DoubleTensor of size 4x5]
关于lua - torch 7 : Filtering out NaN values,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37144125/