tensorflow - 如何从 tensorflow 中保存的检查点恢复特定范围的变量?

标签 tensorflow neural-network deep-learning

import tensorflow as tf
saver = tf.train.Saver() 
saver.restore(...)

但是 saver.restore 只有恢复整个图的选项。我只想恢复特定范围内的那些变量。

提前致谢!

最佳答案

假设您在 InceptionV1 范围内拥有 Google 的 InceptionNet 模型,并且您想要加载它,但要重新训练范围 InceptionRetrained 中的最后一层除外。

假设您已经开始重新训练最后一层,并且您通过 saver2.save(session, 'last_layer.ckpt') 创建了 last_layer.ckpt 文件,下面是如何从两个检查点恢复网络。

saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV1'))
saver1.restore(session, 'inception_model_from_google.ckpt')

saver2 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionRetrained'))
saver2.restore(session, 'last_layer.ckpt')

如果您只重新训练最后一层,请不要忘记通过使用 var_list 参数调用优化器来禁用梯度在网络上的传播(节省时间)。

tf.train.Optimizer(0.0001).minimize(
            loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Inceptionretrained'))

关于tensorflow - 如何从 tensorflow 中保存的检查点恢复特定范围的变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42546365/

相关文章:

Python TLearn - 损失太高

machine-learning - 神经网络 - 我应该删除所有派生/计算变量吗?

classification - 构建平均图像文件时出错(Caffe)

python-3.x - 如何得到逻辑回归的正确答案?

machine-learning - 多类分割的广义骰子损失: keras implementation

python - 值错误: expected 2D or 3D input (got 1D input) PyTorch

python - Tensorflow 添加了一个新的操作,无法从 python 导入

python - tensorflow中评估指标的含义

tensorflow - 使用带有分布式 TF 的 tf.data API

c# - 用于 future 预测的 .NET 神经网络或 AI