python - 如何使用 tf.one_hot 计算一种热编码?

标签 python tensorflow

我正在尝试使用tensorflow构建mnist数据集的y_train的单热编码。我不明白该怎么做?

# unique values 0 - 9
y_train = array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

keras中,我们会做类似的事情

# this converts it into one hot encoding
one hot_encoding = tf.keras.utils.to_categorical(y_train)

tf.one_hot中,我应该向indices深度参数输入什么?完成一次热编码后,如何将其从 2d-tensor 转换回 numpy 数组?

最佳答案

我不熟悉 Tensorflow,但经过一些测试,这是我发现的:

tf.one_hot() 接受一个索引和一个深度索引是实际转换为one-hot编码的值。 深度是指要利用的最大值。

例如,采用以下代码:

y = [1, 2, 3, 2, 1]
tf.keras.utils.to_categorical(y)
sess = tf.Session();
with sess.as_default():
    print(tf.one_hot(y, 2).eval())
    print(tf.one_hot(y, 4).eval())
    print(tf.one_hot(y, 6).eval())

tf.keras.utils.to_categorical(y) 返回以下内容:

array([[0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.]], dtype=float32)

相比之下,tf.one_hot() 选项(2、4 和 6)执行以下操作:

[[0. 1.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 1.]]
[[0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]]
[[0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]]

从这里可以看出,使用 tf.one_hot() 来模拟 tf.keras.utils.to_categorical()深度参数应等于数组中存在的最大值,+1 表示 0。在本例中,最大值为 3,因此编码中有四个可能的值 - 0、1、2 和 3。因此,需要深度为 4 来表示 one-hot 编码中的所有这些值。

至于转换为 numpy,如上所示,使用 Tensorflow session ,在张量上运行 eval() 将其转换为 numpy 数组。有关执行此操作的方法,请参阅 How can I convert a tensor into a numpy array in TensorFlow? .

我不熟悉 Tensorflow,但希望这会有所帮助。

注意:对于 MNIST 而言,深度 10 就足够了。

关于python - 如何使用 tf.one_hot 计算一种热编码?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56695125/

相关文章:

python - 在 python 中应用 PMML 预测器模型

python - 如何使用 pycparser 在函数声明中查找函数名称?

python - CNN : images and parameters, 的多个输入如何合并

python - Tensorflow - 使用自定义比较器对张量进行排序

python - 尝试在 Python 3.6 中导入 Keras 时出错

python - 使用 tensorflow-gpu 1.14 和 tf.distribute.MirroredStrategy() 的自定义训练循环导致 ValueError

python - ctypes 不允许多次取消引用指针

python - 简单的 pythoncurses-application 在运行时使用 100% CPU。这是正常的吗?

python - 根据列表有条件地创建组框

tensorflow - 如何在spyder上使用tensorflow?