python - 为什么 TensorFlow 示例在增加批量大小时会失败?

标签 python tensorflow

我正在查看 Tensorflow MNIST example for beginners并发现在这部分:

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

将批量大小从 100 更改为 204 以上会导致模型无法收敛。它最高可达 204,但在 205 和我尝试过的任何更高数字时,准确度最终会低于 10%。这是错误、算法问题还是其他问题?

这是为 OS X 运行他们的二进制安装,似乎是版本 0.5.0。

最佳答案

您在初学者示例中使用的是非常基本的线性模型吗?

这里有一个调试它的技巧 - 在增加批量大小时观察交叉熵(第一行来自示例,第二行是我刚刚添加的):

cross_entropy = -tf.reduce_sum(y_*tf.log(y))
cross_entropy = tf.Print(cross_entropy, [cross_entropy], "CrossE")

批量大小为 204 时,您会看到:

I tensorflow/core/kernels/logging_ops.cc:64] CrossE[92.37558]
I tensorflow/core/kernels/logging_ops.cc:64] CrossE[90.107414]

但是在 205,您会从一开始就看到这样的序列:

I tensorflow/core/kernels/logging_ops.cc:64] CrossE[472.02966]
I tensorflow/core/kernels/logging_ops.cc:64] CrossE[475.11697]
I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1418.6655]
I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1546.3833]
I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1684.2932]
I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1420.02]
I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1796.0872]
I tensorflow/core/kernels/logging_ops.cc:64] CrossE[nan]

Ack - NaN 出现了。基本上,大批量正在创建如此巨大的梯度,以至于您的模型正在失控——它应用的更新太大,并且大大超出了它应该去的方向。

在实践中,有几种方法可以解决这个问题。您可以将学习率从 0.01 降低到例如 0.005,这会导致最终精度为 0.92。

train_step = tf.train.GradientDescentOptimizer(0.005).minimize(cross_entropy)

或者您可以使用更复杂的优化算法(Adam、Momentum 等)尝试做更多的事情来确定梯度的方向。或者,您可以使用具有更多自由参数的更复杂的模型来分散大梯度。

关于python - 为什么 TensorFlow 示例在增加批量大小时会失败?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33641799/

相关文章:

python - 如何查看用户是否不是多对多关系?

python - 扩展不可变(卡住)数据类

python - 保留使用 TFIDF 制作的模型,以便使用 Scikit for Python 预测新内容

python - Keras/Tensorflow 中带有 if 语句的自定义损失函数

python - 使用 `tensorflow.python.keras.estimator.model_to_estimator` 将 Keras 模型转换为 Estimator API 时如何通知类权重?

python - SQLalchemy Primaryjoin 属性在 Python 3.4 中不起作用(使用 Python 2.7 起作用)

python - 如何发现数据集中的哪些特征具有预测性?

javascript - 尝试将 TensorFlow 保存的模型转换为 TensorFlow.js 模型时出错

c++ - dlib 19.6 多分类器训练数据

python - 关于占位符形状的 tensorflow 错误