python - 过滤 Tensorflow 数据集中的 NaN 值

标签 python tensorflow tensorflow2.0 tensorflow-datasets

是否有一种简单的方法可以从 tensorflow.data.Dataset 实例中过滤所有包含 nan 值的条目?喜欢 Pandas 中的 dropna 方法吗?


简短示例:

import numpy as np
import tensorflow as tf

X = tf.data.Dataset.from_tensor_slices([[1,2,3], [0,0,0], [np.nan,np.nan,np.nan], [3,4,5], [np.nan,3,4]])
y = tf.data.Dataset.from_tensor_slices([np.nan, 0, 1, 2, 3])
ds = tf.data.Dataset.zip((X,y))
ds = foo(ds)  # foo(x) = ?
for x in iter(ds): print(str(x))

我可以为 foo(x) 使用什么来获得以下输出:

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>)
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([3., 4., 5.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

如果你想自己尝试,here is Google Colab notebook .

最佳答案

我的方法与现有答案略有不同。我没有使用 sum,而是使用了 tf.reduce_any:

filter_nan = lambda x, y: not tf.reduce_any(tf.math.is_nan(x)) and not tf.math.is_nan(y)

ds = tf.data.Dataset.zip((X,y)).filter(filter_nan)

list(ds.as_numpy_iterator())
[(array([0., 0., 0.], dtype=float32), 0.0),
 (array([3., 4., 5.], dtype=float32), 2.0)]

关于python - 过滤 Tensorflow 数据集中的 NaN 值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64157389/

相关文章:

python - 图像字幕给出较弱的结果

python - Keras 模型到tensorflow.keras

tensorflow - "Anaconda can not spawn a new process..."我有崇高的文字错误

tensorflow - 使用 tf.layers 时替代 arg_scope

python - tensorflow.keras.preprocessing.text.Tokenizer 中的文本编码与旧的 tfds.deprecated.text.TokenTextEncoder 有何不同

tensorflow - 如何使 keras 模型采用(无,)张量作为输入

python - 如何限制正则表达式的 findall() 方法

python - 为什么这个函数也返回 "None None"?

python - 连接类似的 pandas DataFrame 列,对它们进行排序并用 np.NaN 填充

python - TF2/Keras 切片张量使用 [ :, :, 0]