python - 在预测期间,数据规范化如何在 keras 中工作?

标签 python machine-learning tensorflow neural-network keras

我看到 imageDataGenerator 允许我指定不同样式的数据规范化,例如featurewise_center、samplewise_center 等

我从示例中看到,如果我指定了这些选项之一,那么我需要在生成器上调用 fit 方法,以允许生成器计算像生成器上的平均图像这样的统计数据。

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(X_train)

# fits the model on batches with real-time data augmentation:
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=32),
                samples_per_epoch=len(X_train), nb_epoch=nb_epoch)

我的问题是,如果我在训练期间指定了数据归一化,预测将如何工作?我看不到在框架中我什至会如何传递训练集均值/标准偏差的知识来预测以允许我自己规范化我的测试数据,但我也没有在训练代码中看到这些信息在哪里存储。

标准化所需的图像统计信息是否存储在模型中,以便在预测期间使用?

最佳答案

是的 - 这是 Keras.ImageDataGenerator 的一个非常大的缺点您无法自己提供标准化统计信息。但是 - 有一个简单的方法可以解决这个问题。

假设你有一个函数 normalize(x)这是对图像 batch 进行规范化(请记住,生成器提供的不是简单图像而是图像数组 - batch 形状为 (nr_of_examples_in_batch, image_dims ..),您可以使用使用规范化:

def gen_with_norm(gen, normalize):
    for x, y in gen:
        yield normalize(x), y

那么你可以简单地使用 gen_with_norm(datagen.flow, normalize)而不是 datagen.flow .

此外-您可能会恢复meanstdfit 计算通过从 datagen 中的适当字段(例如 datagen.meandatagen.std )获取它的方法。

关于python - 在预测期间,数据规范化如何在 keras 中工作?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41855512/

相关文章:

python - Tensorflow:在不运行任何 session 的情况下将 Tensor 转换为 numpy 数组

python - 基本 tensorflow 问题(输入和输出数组)

javascript - 如何使用正则表达式提取 JavaScript 变量

python - predict_proba 或 decision_function 作为估计器 "confidence"

当文本中有空格时python不写

python - 使用 sklearn 数字数据集预测数字 - 错误

python - 如何向使用 Keras 构建的 CNN 模型添加第二个输入参数(第一个是图像)?

python - tensorflow - 输入到 StringToHashBucketFast 操作类型错误

python - 导入错误 : No Module named simplejson

python - 按包含字符串的最长元素过滤列表