tensorflow - 使用 Estimator API 更新批量归一化均值和方差

标签 tensorflow machine-learning batch-normalization tensorflow-estimator

文档对此并不是 100% 清楚:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

(参见 https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization )

这是否意味着保存 moving_meanmoving_variance 所需的全部内容如下?

def model_fn(features, labels, mode, params):
   training = mode == tf.estimator.ModeKeys.TRAIN
   extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

   x = tf.reshape(features, [-1, 64, 64, 3])
   x = tf.layers.batch_normalization(x, training=training)

   # ...

  with tf.control_dependencies(extra_update_ops):
     train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

换句话说,只需使用

with tf.control_dependencies(extra_update_ops):

要注意保存moving_meanmoving_variance吗?

最佳答案

是的,添加这些控制依赖项将保存均值和方差。

关于tensorflow - 使用 Estimator API 更新批量归一化均值和方差,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49204810/

相关文章:

machine-learning - 多个子特征的特征提取

python - Python 中是否有 quadprog() 的替代方案

python - 仅在恢复模型时输出中出现 TensorFlow NaN

python - 如何使用 matplotlib (python) colah 的变形网格进行绘图?

machine-learning - 为什么仅在 CNN 中对 channel 进行批量归一化

python - 为什么tf.layers.batch_normalization的参数 'scale'在下一层是relu时被禁用?

python - tensorflow 的实现比 torch 的慢 2 倍

Python TensorFlow - 什么是 tf.flags.FLAGS?

python - 如何在使用 ImageDataGenerator 时使用 to_categorical

python - 现实与预测之间的延迟差距