从保存的检查点恢复训练模型时,Tensorflow 批量损失峰值?

标签 tensorflow

我遇到了一个奇怪的问题,我一直在尝试调试,但运气不佳。我的模型开始正确训练,批量损失持续减少(从最初的 ~6000 到 20 轮后的 ~120)。但是,当我暂停训练并稍后通过从检查点恢复模型来恢复训练时,批次损失似乎从之前的批次损失(暂停之前)意外飙升,并从更高的损失点恢复下降。我担心的是,当我恢复模型进行评估时,我可能没有使用我认为的经过训练的模型。

与 Tensorflow 教程相比,我已经多次梳理了我的代码。我尝试确保使用教程建议的方法进行保存和恢复。这是代码快照:https://github.com/KaranKash/DigitSpeak/tree/b7dad3128c88061ee374ae127579ec25cc7f5286 - train.py 文件包含保存和恢复步骤、图形设置和训练过程;而 model.py 创建网络层并计算损失。

这是我的打印语句中的一个示例 - 当从 epoch 7 的检查点恢复训练时,注意批量损失急剧上升:

Epoch 6. Batch 31/38. Loss 171.28
Epoch 6. Batch 32/38. Loss 167.02
Epoch 6. Batch 33/38. Loss 173.29
Epoch 6. Batch 34/38. Loss 159.76
Epoch 6. Batch 35/38. Loss 164.17
Epoch 6. Batch 36/38. Loss 161.57
Epoch 6. Batch 37/38. Loss 165.40
Saving to /Users/user/DigitSpeak/cnn/model/model.ckpt
Epoch 7. Batch 0/38. Loss 169.99
Epoch 7. Batch 1/38. Loss 178.42
KeyboardInterrupt
dhcp-18-189-118-233:cnn user$ python train.py
Starting loss calculation...
Found in-progress model. Will resume from there.
Epoch 7. Batch 0/38. Loss 325.97
Epoch 7. Batch 1/38. Loss 312.10
Epoch 7. Batch 2/38. Loss 295.61
Epoch 7. Batch 3/38. Loss 306.96
Epoch 7. Batch 4/38. Loss 290.58
Epoch 7. Batch 5/38. Loss 275.72
Epoch 7. Batch 6/38. Loss 251.12

我已经打印了 inspect_checkpoint.py 脚本的结果。我还尝试了其他损失函数(Adam 和 GradientDescentOptimizer),并注意到恢复训练后峰值损失的相同行为。
dhcp-18-189-118-233:cnn user$ python inspect_checkpoint.py
Optimizer/Variable (DT_INT32) []
conv1-layer/bias (DT_FLOAT) [64]
conv1-layer/bias/Momentum (DT_FLOAT) [64]
conv1-layer/weights (DT_FLOAT) [5,23,1,64]
conv1-layer/weights/Momentum (DT_FLOAT) [5,23,1,64]
conv2-layer/bias (DT_FLOAT) [512]
conv2-layer/bias/Momentum (DT_FLOAT) [512]
conv2-layer/weights (DT_FLOAT) [5,1,64,512]
conv2-layer/weights/Momentum (DT_FLOAT) [5,1,64,512]

最佳答案

我遇到了这个问题,发现这是我在恢复图形时初始化图形变量的事实——丢弃所有学习的参数,用原始图形定义中最初为每个相应张量指定的任何初始化值替换。

例如,如果您使用 tf.global_variable_initializer() 要将变量初始化为模型程序的一部分,无论您的控制逻辑表明将恢复保存的图形,请确保图形恢复流程省略: sess.run(tf.global_variable_initializer())

这对我来说是一个简单但代价高昂的错误,所以我希望其他人能省下几根白发(或一般的头发)。

关于从保存的检查点恢复训练模型时,Tensorflow 批量损失峰值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42029046/

相关文章:

python - 尝试 conda 安装 tensorflow 1.4.1 时发生奇怪的 `glibc==2.17` 冲突

python - 将keras后端从tensorflow cpu更改为gpu

docker - 在Mac OS上,如何使用docker容器编译C++以创建Linux二进制文件

machine-learning - 如何在TensorFlow中打印CNN训练过程中每个epoch的准确率?

python - Tensorflow 2 抛出 ValueError : as_list() is not defined on an unknown TensorShape

python - 使用 TensorFlow 训练神经网络时出错

tensorflow - 在 Google Colab 中保存 TensorFlow 检查点

python - DropoutWrapper 在运行中是不确定的?

python - 在内存中序列化和反序列化 Tensorflow 模型并继续训练

python - 访问内部张量并向 tflite 模型添加新节点?