python - 为什么我们使用 numpy.argmax() 从 numpy 预测数组中返回索引?

标签 python numpy tensorflow machine-learning keras

首先我要说的是,我对神经网络非常陌生,这是我第一次使用 numpy、tensorflow 或 keras。

我使用 MNIST 数据集编写了一个神经网络来识别手写数字。我关注了this tutorial由 Sentdex 发现他正在使用 print(np.argmax(predictions[0])) 打印 numpy 预测数组中的第一个索引。

我尝试运行该程序,将该行替换为 print(predictions[i]),(i 设置为 0),但输出不是数字,而是: [2.1975785e-08 1.8658861e-08 2.8842608e-06 5.7113186e-05 1.2067199e-10 7.2511304e-09 1.6282028e-12 9.9993789e-01 1.3356166e-08 2.0409643e-06]

我感到困惑的代码是:

predictions = model.predict(x_test)
for i in range(10):
   plt.imshow(x_test[i])
   plt.show()
   print("PREDICTION: ", predictions[i])

我阅读了 argmax() 函数的 numpy 文档,据我了解,它接受一个 x 维数组,将其转换为一维数组,然后返回最大值的索引。 model.predict() 的 Keras 文档表明该函数返回网络预测的 numpy 数组。 所以我不明白为什么我们必须使用 argmax() 来正确打印预测,因为据我了解,它有一个完全不相关的目的。

很抱歉代码格式错误,我不知道如何正确地将多行代码块插入到我的帖子中

最佳答案

任何分类神经网络输出的是类别索引上的概率分布,这意味着网络为每个类别分配一个概率。这些概率的总和是 1.0。然后训练网络将最高概率分配给正确的类别,因此要从概率中恢复类别索引,您必须采用具有最大概率的位置(索引)。这是通过 argmax 操作完成的。

关于python - 为什么我们使用 numpy.argmax() 从 numpy 预测数组中返回索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56382596/

相关文章:

Tensorflow 服务在基本路径下找不到可服务的 <MODEL> 版本

python - OpenShift、python 2.7 和带 htaccess 的静态文件

python argparse : allow unregistered arguments

python - 如果 np.arange 喂养,为什么 np.arccos(1.0) 会给出 nan?

python - 将 retrain.py 的输出转换为 tensorflow.js

python - 为分组数据的 RNN 生成具有特定长度的序列/批处理

python - 在 Heroku 上使用环境变量作为凭证

python - 有没有办法在 python 的列表推导中使用两个 if 条件

python - 如何拥有一个没有距离较近的元素对的数组

python - 在 Python/NumPy 中计算平均值的元素排列