python - 从 PyTorch N 维张量中过滤出 NaN 值

标签 python python-3.x pytorch filtering nan

这个问题很相似to filtering np.nan values from pytorch in a -Dimensional tensor .不同之处在于我想将相同的概念应用于 2 维或更高维的张量。
我有一个看起来像这样的张量:

import torch

tensor = torch.Tensor(
[[1, 1, 1, 1, 1],
 [float('nan'), float('nan'), float('nan'), float('nan'), float('nan')],
 [2, 2, 2, 2, 2]]
)
>>> tensor.shape
>>> [3, 5]
我想找到最 Pythonic/PyTorch 的方法来过滤(删除)张量的行 nan .通过过滤此 tensor沿着第一个(0 th 轴)我想获得一个 filtered_tensor看起来像这样:
>>> print(filtered_tensor)
>>> torch.Tensor(
[[1, 1, 1, 1, 1],
 [2, 2, 2, 2, 2]]
)
>>> filtered_tensor.shape
>>> [2, 5]

最佳答案

使用 PyTorch 的 isnan()连同any()切片 tensor的行使用获得的 bool 掩码如下:

filtered_tensor = tensor[~torch.any(tensor.isnan(),dim=1)]
请注意,这将删除具有 nan 的任何行。其中的值(value)。如果您只想删除所有值为 nan 的行替换 torch.anytorch.all .
对于 N 维张量,您可以将除第一个暗淡之外的所有暗淡变平并应用与上述相同的过程:
#Flatten:
shape = tensor.shape
tensor_reshaped = tensor.reshape(shape[0],-1)
#Drop all rows containing any nan:
tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(),dim=1)]
#Reshape back:
tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0],*shape[1:])

关于python - 从 PyTorch N 维张量中过滤出 NaN 值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64594493/

相关文章:

Python:使用字典从列表中删除重复项同时保留顺序

python - 超出 IOPub 数据速率的 Jupyter Notebook 错误消息

regex - 使用字典替换字符串 - 正则表达式

python - PyTorch 中相同形状的掩蔽张量

python - 通过压缩将 2 个 numpy 维度 reshape 为一个维度?

pytorch - 如何禁用 TOKENIZERS_PARALLELISM=(true | false) 警告?

python - 如何在pycharm中加载pcd文件

python - 在 PyCharm 中,如何启用将在非启动文件的文件中激活的断点?

python - 使用 SQLAlchemy 从 MySQL 获取最后插入的值

python-3.x - 使用 opencv 和 python 在保存的视频上单击鼠标事件