使用 Tensorflow,是否有输出网络预测的方法?
我的输出一直在为 12 个类使用 One Hot Representation
[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]
etc...
当尝试从我的模型中获取给定输入的预测时,我运行了以下代码
prediction=tf.argmax(y,1)
best = sess.run([prediction],feed_dict={x: batch_x, y: batch_y,
seqlen: batch_seqlen})
print("Prediction: ")
print(best)
当我运行这段代码并打印预测时,我的输出是:
[array([1, 5, 7, 7, 7, 4, 7, 9, 4, 4, 9, 6, 7, 8, 3, 2], dtype=int64)]
我输入的批量大小是 16,所以有 16 个输出确实有意义。然而,这些都不是 OneHot 表示(不确定 tensorflow 的输出是否打算作为索引进行交互,所以 1 实际上是某种形式的 onehot
有没有一种方法可以让每个特定的 X 创建一个预测排名列表,模型在给定 X 的情况下最有可能找到什么?
这有意义吗?
最佳答案
您正在获取 1-hot 向量的 tf.argmax
,所以这就是您看到索引而不是您预期的 1-hot 向量的原因。
要获得分类预测的排序列表,您可以获取预测向量并应用 values, indices = tf.nn.top_k(prediction)
values
将是您的预测按降序排列,indices
将是那些排序的 values
索引。
关于python - 生成预测列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53180481/