python - 您能解释一下具有 BatchNormalization 的神经网络中的 Keras get_weights() 函数吗?

标签 python keras

当我在 Keras 中运行神经网络(没有 BatchNormalization)时,我了解 get_weights() 函数如何提供神经网络的权重和偏差。然而,使用 BatchNorm 它会产生 4 个额外参数,我假设是 Gamma、Beta、Mean 和 Std。

当我保存这些值时,我尝试手动复制一个简单的神经网络,但无法让它们产生正确的输出。有谁知道这些值是如何运作的?

No Batch Norm

With Batch Norm

最佳答案

我将举一个例子来解释简单的多层感知器(MLP)和具有批量归一化(BN)的MLP的情况下的get_weights()。

示例:假设我们正在处理 MNIST 数据集,并使用 2 层 MLP 架构(即 2 个隐藏层)。隐藏层 1 中的神经元数量为 392,隐藏层 2 中的神经元数量为 196。因此,我们的 MLP 的最终架构将为 784 x 512 x 196 x 10

这里784是输入图像维度,10是输出层维度

案例1:没有批量归一化的MLP => 让我的模型名称为使用ReLU激活函数的model_relu。现在,在训练model_relu之后,我正在使用 get_weights(),这将返回一个大小为 6 的列表,如下面的屏幕截图所示。

get_weights() with simple MLP and without Batch Norm列表值如下:

  • (784, 392):隐藏层 1 的权重
  • (392,):与隐藏层 1 的权重相关的偏差

  • (392, 196):隐藏层 2 的权重

  • (196,):与隐藏层2的权重相关的偏差

  • (196, 10):输出层的权重

  • (10,):与输出层权重相关的偏差

案例2:带有批量归一化的MLP => 让我的模型名称为model_batch,它也使用ReLU激活函数和批量归一化。现在,在训练 model_batch 后,我正在使用 get_weights(),这将返回一个大小为 14 的列表,如下面的屏幕截图所示。

get_weights() with Batch Norm 列表值如下:

  • (784, 392):隐藏层 1 的权重
  • (392,):与隐藏层 1 的权重相关的偏差
  • (392,) (392,) (392,) (392,):这四个参数是 gamma、beta、mean 和 std。大小为 392 的 dev 值,每个值与隐藏层 1 的批量归一化相关联。

  • (392, 196):隐藏层 2 的权重

  • (196,):与隐藏层 2 的权重相关的偏差
  • (196,) (196,) (196,) (196,):这四个参数是 gamma、beta、运行平均值和 std。大小为 196 的 dev,每个与隐藏层 2 的批量归一化相关。

  • (196, 10):输出层的权重

  • (10,):与输出层权重相关的偏差

因此,在 case2 中,如果你想获取隐藏层 1、隐藏层 2 和输出层的权重,Python 代码可以是这样的:

wrights = model_batch.get_weights()      
hidden_layer1_wt = wrights[0].flatten().reshape(-1,1)     
hidden_layer2_wt = wrights[6].flatten().reshape(-1,1)     
output_layer_wt = wrights[12].flatten().reshape(-1,1)

希望这有帮助!

Ref: keras-BatchNormalization

关于python - 您能解释一下具有 BatchNormalization 的神经网络中的 Keras get_weights() 函数吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57087273/

相关文章:

python - 将二维多项式系数附加到 Python 中的项的更快方法?

python - TensorFlow - 当损失达到定义值时停止训练

python - 名称错误 : name 'image' is not defined

machine-learning - Keras 有状态 RNN 数据分批的正确方法

python - GAE Blobstore : upload blob along with other text fields

python - 使用 lxml 和 Python 解析嵌套的 xml

python - 打印列表项和索引

python - 双向 LSTM 的问题

machine-learning - 使用不同长度的时间序列在 Keras 中训练 LSTM

python - 从 AWS Lambda 连接 Oracle