python - Keras BatchNormalization,究竟什么是样本标准化?

标签 python neural-network keras

我想弄清楚 Keras 中的批量归一化究竟做了什么。现在我有以下代码。

for i in range(8):
    c = Convolution2D(128, 3, 3, border_mode = 'same', init = 'he_normal')(c)
    c = LeakyReLU()(c)
    c = Convolution2D(128, 3, 3, border_mode = 'same', init = 'he_normal')(c)
    c = LeakyReLU()(c)
    c = Convolution2D(128, 3, 3, border_mode = 'same', init = 'he_normal')(c)
    c = LeakyReLU()(c)
    c = merge([c, x], mode = 'sum')
    c = BatchNormalization(mode = 1)(c)
    x = c

我根据 Keras 文档 1: sample-wise normalization 将 batch norm mode 设置为 1。此模式假定 2D 输入。

我认为这应该做的只是独立于每个其他样本对批处理中的每个样本进行归一化。但是,当我查看调用函数的源代码时,我看到了以下内容。

    elif self.mode == 1:
        # sample-wise normalization
        m = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        x_normed = (x - m) / (std + self.epsilon)
        out = self.gamma * x_normed + self.beta

在这里它只是计算所有 x 的平均值,在我的例子中是 (BATCH_SIZE, 128, 56, 56) 我认为。我认为它应该在模式 1 下独立于批处理中的其他样本进行归一化。那么 axis = 1 不应该吗?另外,文档中的“假定 2D 输入”是什么意思?

最佳答案

In this it is just computing the mean over all of x which in my case is (BATCH_SIZE, 128, 56, 56) I think.

这样做你已经违反了该层的契约(Contract)。这不是 2 维输入,而是 4 维输入。

I thought it was supposed to normalize independent of the other samples in the batch when in mode 1

确实如此。 K.mean(..., axis=-1) 减少轴 -1,它与输入的最后一个轴同义。因此,假设输入形状为 (batchsz, features),轴 -1 将是 features 轴。

由于 K.meannumpy.mean 非常相似,您可以自己测试一下:

>>> x = [[1,2,3],[4,5,6]]
>>> x
array([[1, 2, 3],
       [4, 5, 6]])
>>> np.mean(x, axis=-1)
array([ 2.,  5.])

您可以看到批处理中每个样本的特征都减少了。

关于python - Keras BatchNormalization,究竟什么是样本标准化?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37625272/

相关文章:

python - 小数据集上的 CNN 过度拟合

python - 带有 Tensorflow (1.3) 后端的 Keras (2.0.8) 占用所有可用内存

python - 是否可以将带有无关元素的字典传递给 Django object.create 方法?

machine-learning - 使用keras训练多类nn时,loss无法进一步往下走,可能是什么原因

neural-network - OpenAI 健身房的月球着陆器模型未收敛

machine-learning - 如何记录或查看使用 Dropout 训练 TensorFlow 神经网络时所用的成本?

python - 池化后预期 Keras 形状不匹配

python - 来自 Python 项目的 RPM 子包

python - 在 QCompleter 激活调用时查找 QStandardItemModel 的索引

python - 如何在python中计算skipgrams?