python - 如何恢复预训练模型以初始化参数

标签 python tensorflow

我已经下载了一个带有预训练模型的网络。我在网络中添加了几个层和参数,我想使用这个预训练模型来初始化原始参数,并自己随机初始化新添加的参数。我使用此代码:

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "output/saver-test")
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

但是我遇到了错误:“Key global_step not found in checkpoint”,这个错误是因为我有一些在预训练模型中不存在的新参数。但是我该如何解决这个问题呢?更何况,我想使用这段代码“sess.run(tf.global_variables_initializer())”来初始化新添加的参数,但是从预训练模型中提取的参数会被它覆盖吗?

最佳答案

发生这种情况是因为您的网络与加载的网络不完全匹配。 您可以使用类似的选择性检查点加载程序:

  reader = tf.train.NewCheckpointReader(os.path.join(checkpoint_dir, ckpt_name))
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v

    restore_dict['my_new_var_scope/my_new_var'] = self.get_my_new_var_variable()

其中 get_my_new_var_variable() 是这样的:

    def get_my_new_var_variable(self):
    with tf.variable_scope("my_new_var_scope",reuse=tf.AUTO_REUSE):
        my_new_var = tf.get_variable("my_new_var", dtype=tf.int32,initializer=tf.constant([23, 42]))
    return my_new_var

加载权重:

self.saver = tf.train.Saver(restore_dict)
    self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))

编辑:

请注意,为了避免覆盖加载的变量,您可以使用此方法:

def initialize_uninitialized(sess):
  global_vars = tf.global_variables()
  is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
  not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
  if len(not_initialized_vars):
    sess.run(tf.variables_initializer(not_initialized_vars))

或者在加载变量之前简单地调用 tf.global_variables_initializer() 应该在这里工作。

关于python - 如何恢复预训练模型以初始化参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52532150/

相关文章:

python - 调用 object.method() 和 Class.method(object) 时幕后发生了什么?

python - ValueError : Error when checking input: expected conv2d_1_input to have shape (224, 224, 1) 但得到形状为 (224, 224, 8) 的数组

python - 如何记录 TensorFlow 变量中的各个标量值?

tensorflow - Keras 模型在微调时变得更糟

python - 类中的字典和 Lambda?

python - 如何使用python向mysql数据库插入变量? "You have an error in your SQL syntax"

machine-learning - Tensorflow:启动新 session 时出现扭矩和 GPU 问题:CUDA_ERROR_INVALID_DEVICE

python - 如何循环打印模型的所有 tf​​.Tensors?

python - Pymongo:选择字段到列表中,而不将字段名称添加到列表中

python - 使用多处理器模块中的 Manager 更新嵌套字典