python - Tensorflow 中的目标必须是一维 Top_k_categorical_accuracy

标签 python tensorflow machine-learning keras deep-learning

我刚刚完成了 Inception V3 CNN 的训练,我正在尝试测量训练数据集的准确性,特别是 top-k 准确性。我从 tensorflow.keras.metrics 调用名为 top_k_categorical_accuracy 的函数,正确排序我的参数 (y_true, y_pred, k) 但我收到一条错误消息我的目标 (y_true) 应该是一维的。但是,当我打印 y_true 的形状(如果我理解正确的话,这是目标)时,我得到 (9000,),对我来说,它似乎是一维的.

两个数组都有一个 dtype = "float32" 因为我在一个线程中读到这导致了问题,但这并不能解决我的问题。

import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
from keras.applications import InceptionV3
from keras.applications.inception_v3 import preprocess_input

test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

test_generator = test_datagen.flow_from_directory(
    "data/test",
    target_size=(299, 299),
    batch_size=16,
    class_mode="categorical",
    shuffle=True,
    seed=42,
)

STEP_SIZE_TEST = test_generator.n // test_generator.batch_size


model = keras.models.load_model("inceptionv3.hdf5")

results = model.evaluate_generator(test_generator, STEP_SIZE_TEST, workers=8)

y_pred = model.predict_generator(test_generator)
print(y_pred.shape) # Prints (9000, 6)
y_true = test_generator.classes
y_true = y_true.astype("float32") 
print(y_true.shape) #Prints (9000,)

top_k = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)

我得到的确切错误是这样的:tensorflow.python.framework.errors_impl.InvalidArgumentError:目标必须是一维的[Op:InTopKV2]

如果我将 y_pred 大小调整为一维数组,则会出现以下错误:tensorflow.python.framework.errors_impl.InvalidArgumentError:预测必须是二维 [Op:InTopKV2]

最佳答案

你试过这个吗?

y_pred = np.argmax(y_pred, axis=1)

据我了解,最后一层有类似 Dense(6,activation='softmax') 的内容。这就是为什么 y_pred 是矩阵。上面的脚本可以提供帮助。

关于python - Tensorflow 中的目标必须是一维 Top_k_categorical_accuracy,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61915461/

相关文章:

machine-learning - Caffe 只训练一个标签

python - (MNIST - GAN)第一次迭代后鉴别器和生成器误差下降到接近于零

python - Tensorflow 中多类分类的类智能精度和召回率?

tensorflow - 使用 keras 和 tensorflow "You must feed a value for placeholder tensor ' input_ 1' with dtype float"

python - Tensorflow: session 之间正确的队列关闭

python - 如何正确地将 tflite_graph.pb 转换为 detect.tflite

azure - 如何在 Microsoft Azure 中使用 SMOTE

python - TypeError: 'list' 对象在使用属性时不可调用

python - 如何使用 urllib2.urlopen 发出没有数据参数的 POST 请求

Python:哪个 XML 解析器支持 DTD !ENTITY 定义?