import tensorflow as tf
saver = tf.train.Saver()
saver.restore(...)
但是 saver.restore 只有恢复整个图的选项。我只想恢复特定范围内的那些变量。
提前致谢!
最佳答案
假设您在 InceptionV1
范围内拥有 Google 的 InceptionNet 模型,并且您想要加载它,但要重新训练范围 InceptionRetrained
中的最后一层除外。
假设您已经开始重新训练最后一层,并且您通过 saver2.save(session, 'last_layer.ckpt')
创建了 last_layer.ckpt 文件,下面是如何从两个检查点恢复网络。
saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV1'))
saver1.restore(session, 'inception_model_from_google.ckpt')
saver2 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionRetrained'))
saver2.restore(session, 'last_layer.ckpt')
如果您只重新训练最后一层,请不要忘记通过使用 var_list
参数调用优化器来禁用梯度在网络上的传播(节省时间)。
tf.train.Optimizer(0.0001).minimize(
loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Inceptionretrained'))
关于tensorflow - 如何从 tensorflow 中保存的检查点恢复特定范围的变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42546365/