python - 使用 keras ImageDataGenerator 时如何在测试阶段对图像应用标准化?

标签 python machine-learning keras image-preprocessing

我正在尝试使用经过训练的模型来预测新图像。我的准确率是95%。但无论我输入什么,predict_classes 总是返回第一个标签 [0]。 我想原因之一是我在 ImageDataGenerator 中使用了 featurewise_center=Truesamplewise_center=True 。我想我应该对我的输入图像做同样的事情。但我找不到这些函数对图像做了什么。

如有任何建议,我们将不胜感激。

ImageDataGenerator代码:

train_datagen = ImageDataGenerator(
samplewise_center=True,
rescale=1. / 255,
shear_range=30,
zoom_range=30,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2)

test_datagen = ImageDataGenerator(
samplewise_center=True,
rescale=1. / 255,
shear_range=30,
zoom_range=30,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True)

预测代码(我使用100*100*3图像来训练模型):

model = load_model('CNN_model.h5')
img = cv2.imread('train/defect/6.png')
img = cv2.resize(img,(100,100))
img = np.reshape(img,[1,100,100,3])
img = img/255.

classes = model.predict_classes(img)

print (classes)

11/14 更新:

我更改了代码来预测图像,如下所示。但即使我输入用于训练模型的图像(并且准确率达到 95%),模型仍然会预测同一类。有什么我错过的吗?

model = load_model('CNN_model.h5')
img = cv2.imread('train/defect/6.png')
img = cv2.resize(img,(100,100))
img = np.reshape(img,[1,100,100,3])
img = np.array(img, dtype=np.float64) 
img = train_datagen.standardize(img)

classes = model.predict_classes(img)
print(classes)

最佳答案

您需要使用ImageDataGenerator实例的standardize()方法。来自 Keras documentation :

standardize

standardize(x)

Applies the normalization configuration to a batch of inputs.

Arguments

  • x: Batch of inputs to be normalized.

Returns

The inputs, normalized.

所以它会是这样的:

img = cv2.imread('train/defect/6.png')
img = cv2.resize(img,(100,100))
img = np.reshape(img,[1,100,100,3])
img = train_datagen.standardize(img)

classes = model.predict_classes(img)

请注意,它也会应用重新缩放,因此无需自己执行此操作(即删除 img = img/255.)。

此外,请记住,由于您设置了 featurewise_ceneter=True,因此您需要使用 fit()使用生成器进行训练之前的方法:

train_datagen.fit(training_data)

# then use fit_generator method
model.fit_generator(train_datagen, ...)

关于python - 使用 keras ImageDataGenerator 时如何在测试阶段对图像应用标准化?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53186129/

相关文章:

python - 将 pandas 数据帧传递给 fastapi

optimization - 非凸损失函数

python - 在 Django 模型中不存在的 django-rest-framework 序列化器上定义任意字段

machine-learning - 选择 GeForce 或 Quadro GPU 通过 TensorFlow 进行机器学习

python - 具有哈希向量化器精度的 Scikit SGD 分类器停留在 58%

python - 训练 DQN 时 Q 值爆炸

python - 在 tensorflow 模型中随机选择层

python - 在循环中更改模型(keras,python)

python - 判断函数是否是嵌套函数

python - 我如何在 Django 中定义 3 个以上模型之间的多对多关系?