python - TensorFlow - 如果存在则恢复

标签 python tensorflow

是否有可能恢复一个变量,只有当它存在时?这样做最惯用的方法是什么?

例如,考虑以下最小示例:

import tensorflow as tf
import glob
import sys
import os

with tf.variable_scope('volatile'):
    x = tf.get_variable('x', initializer=0)

with tf.variable_scope('persistent'):
    y = tf.get_variable('y', initializer=0)
    add1 = tf.assign_add(y, 1)

saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'persistent'))

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.get_default_graph().finalize()

print('save file', sys.argv[1])
if glob.glob(sys.argv[1] + '*'):    
    saver.restore(sess, sys.argv[1])

print(sess.run(y))
sess.run(add1)
print(sess.run(y))
saver.save(sess, sys.argv[1])

当使用相同的参数运行两次时,程序首先打印 0\n1,然后按预期打印 1\n2。现在假设您更新代码以获得新功能,方法是在 persistent 范围。当存在旧的保存文件时再次运行此命令将中断以下内容:

NotFoundError (see above for traceback): Key persistent/z not found in checkpoint
     [[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_INT32],
         _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0,
                                                                 save/RestoreV2_1/tensor_names,
                                                                 save/RestoreV2_1/shape_and_slices)]]
     [[Node: save/Assign_1/_18 = _Recv[client_terminated=false,
         recv_device="/job:localhost/replica:0/task:0/device:GPU:0",
         send_device="/job:localhost/replica:0/task:0/device:CPU:0",
         send_device_incarnation=1,
         tensor_name="edge_12_save/Assign_1",
         tensor_type=DT_FLOAT,
         _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

最佳答案

您可以使用以下函数进行恢复(取自 here ):

def optimistic_restore(session, save_file, graph=tf.get_default_graph()):
    reader = tf.train.NewCheckpointReader(save_file)
    saved_shapes = reader.get_variable_to_shape_map()
    var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables()
            if var.name.split(':')[0] in saved_shapes])    
    restore_vars = []    
    for var_name, saved_var_name in var_names:            
        curr_var = graph.get_tensor_by_name(var_name)
        var_shape = curr_var.get_shape().as_list()
        if var_shape == saved_shapes[saved_var_name]:
            restore_vars.append(curr_var)
    opt_saver = tf.train.Saver(restore_vars)
    opt_saver.restore(session, save_file)

我通常运行 sess.run(tf.global_variables_initializer()) 以确保所有变量都已初始化,然后我运行 optimistic_restore(sess,...) 恢复可以恢复的变量。

关于python - TensorFlow - 如果存在则恢复,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47997203/

相关文章:

python - 找不到满足要求的版本 Pillow==2.7.0

python - Pandas to_numeric(错误 ='coerce' )不会将无效值转换为 nan

python - 高效查找具有割断的邻居并返回索引

tensorflow - 预训练的 Tensorflow 模型 RGB -> RGBY channel 扩展

tensorflow - tf.estimator.Estimator.train() 是否保持 input_fn 状态

python - 为什么我的人工神经网络不学习?

python - 匹配以逗号分隔的精确长度的所有单词

python - Kaggle 泰坦尼克号与 tflearn 神经网络

python - 在 M1 Mac 中安装 Tensorflow

python - 使用 Keras 稀疏分类交叉熵进行逐像素多类分类