python - 如何解释 Keras 中 model.predict() 的输出

标签 python numpy tensorflow machine-learning keras

当我尝试执行预测图像时,我的代码出现问题。使用keras等

我正在寻找如何输出数组的方法

例如[1,0,0]然后输出rock

import numpy as np
from google.colab import files
from keras.preprocessing import image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.applications.vgg16 import VGG16

%matplotlib inline

uploaded = files.upload()

for fn in uploaded.keys():
 
  # predicting images
  path = fn
  img = image.load_img(path, target_size=(150,150))
  imgplot = plt.imshow(img)
  x = image.img_to_array(img)
  x = np.expand_dims(x, axis=0)
  x = preprocess_input(x)

  #images = np.vstack([x])
  classes = model.predict(x, batch_size=10)
  print(classes)

  print(fn)
  if classes==[[1,0,0]]:
    print('paper')
  else:
    print('rock')

然后是这样的输出

Saving 0a3UtNzl5Ll3sq8K.png to 0a3UtNzl5Ll3sq8K (4).png
[[1. 0. 0.]]
0a3UtNzl5Ll3sq8K.png
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-69-863494647f7a> in <module>()
     28 
     29   print(fn)
---> 30   if classes==[[1,0,0]]:
     31     print('paper')
     32   else:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

截图程序: enter image description here

最佳答案

始终检查您正在使用的对象的类型。

返回类型是张量数组,不是列表;它实际上是每个标签的一系列概率。为了将其转换为 numpy 数组,您需要使用 prediction.numpy()

在您的情况下,混淆来自这样一个事实,即第一个标签的概率确实为 100%,而其余标签的概率为 0%。

除此之外,还要注意你比较的方式:

[[1. 0. 0.]][[1,0,0]]

您需要使用 argmax() 才能正确获取标签。

关于python - 如何解释 Keras 中 model.predict() 的输出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62726643/

相关文章:

python - 如何只从特定的 gmail 标签下载未读的附件?

python - Numpy einsum 二维数组的外和

python - 分类到二进制 - 我做错了什么?

tensorflow - 使用 tf.data.Dataset 评估性能的最佳方式

python - 反向传播算法陷入训练 AND 函数的困境

python - 基于数据框过滤数据透视表

python - Homebrew python 安装选项

python - 在 python 方法中处理异常的正确方法是什么?

Python numpy 保留已排序二维数组的索引列表

tensorflow keras嵌入LSTM