tensorflow - 如何在 TensorFlow 中编码标签?

标签 tensorflow

我需要将我的字符串标签转换为像 [0, 0, ... , 1, ... 0] 这样的向量。
据我所知,这是一种称为一个热向量的东西。
我有 10 个类,所以有 10 个不同的字符串标签。

任何人都可以帮助进行直接和逆变换吗?
我是 tensorflow 的新手,所以请善待。

最佳答案

向前的方向很容易,因为有 tf.one_hot操作:

import tensorflow as tf

original_indices = tf.constant([1, 5, 3])
depth = tf.constant(10)
one_hot_encoded = tf.one_hot(indices=original_indices, depth=depth)

with tf.Session():
  print(one_hot_encoded.eval())

输出:
[[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]]

反过来也不错,tf.where找到非零索引:
def decode_one_hot(batch_of_vectors):
  """Computes indices for the non-zero entries in batched one-hot vectors.

  Args:
    batch_of_vectors: A Tensor with length-N vectors, having shape [..., N].
  Returns:
    An integer Tensor with shape [...] indicating the index of the non-zero
    value in each vector.
  """
  nonzero_indices = tf.where(tf.not_equal(
      batch_of_vectors, tf.zeros_like(batch_of_vectors)))
  reshaped_nonzero_indices = tf.reshape(
      nonzero_indices[:, -1], tf.shape(batch_of_vectors)[:-1])
  return reshaped_nonzero_indices

with tf.Session():
  print(decode_one_hot(one_hot_encoded).eval())

打印:
[1 5 3]

关于tensorflow - 如何在 TensorFlow 中编码标签?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42092508/

相关文章:

tensorflow - 在linux上构建tensorflow时出错

python - 如何支持自定义 Tensorflow 层中的混合精度?

tensorflow - tf.contrib.rnn.static_rnn - 'Tensor' 对象不可迭代

python - TensorFlow:加载模型并对图像进行前向传递

python - 二维数据的 Keras 内置 MSE 损失返回二维矩阵,而不是标量损失

python - Tensorflow GetNext() 失败,因为迭代器尚未初始化

python - TensorFlow 1.5.0-rc0 : error using `tf.app.flags`

python - 使用自定义方法保存/加载 Keras 模型

python - TensorFlow 运行时错误 : MetaGraphDef associated with tags serve could not be found in SavedModel

python-3.x - 当我执行 "rasa init"错误 : "failed to install native tensorflow runtime"