python - 在同一 Tensorflow session 中从 Saver 加载两个模型

标签 python tensorflow

我有两个网络:一个生成输出的 Model 和一个对输出评分的 Adversary

两者都是单独训练的,但现在我需要在一个 session 中合并它们的输出。

我已经尝试实现这篇文章中提出的解决方案:Run multiple pre-trained Tensorflow nets at the same time

我的代码

with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)

#...

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Get the variables specific to the `Model`
    # Also strip out the surperfluous ":0" for some reason not saved in the checkpoint
    model_varlist = {v.name.lstrip("model/")[:-2]: v 
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("avd/")[:-2]: v 
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)

问题

对函数 model_saver.restore() 的调用似乎什么也没做。在另一个模块中,我使用了一个带有 tf.train.Saver(tf.global_variables()) 的保护程序,它可以很好地恢复检查点。

模型有 model.tvars = tf.trainable_variables()。为了检查发生了什么,我使用 sess.run() 在恢复前后提取了 tvars。每次使用初始随机分配的变量并且不分配来自检查点的变量。

关于为什么 model_saver.restore() 似乎什么都不做有什么想法吗?

最佳答案

解决这个问题花了很长时间,所以我发布了我可能不完美的解决方案,以防其他人需要它。

为了诊断问题,我手动遍历了每个变量并一一分配给它们。然后我注意到在分配变量后名称会改变。此处对此进行了描述:TensorFlow checkpoint save and read

根据那篇文章中的建议,我在各自的图表中运行了每个模型。这也意味着我必须在其自己的 session 中运行每个图表。这意味着以不同的方式处理 session 管理。

首先我创建了两个图表

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

然后两个session

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

然后我在每个 session 中初始化变量并分别恢复每个图

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

从这里开始,每当需要每个 session 时,我都会在该 session 中使用 with sess.as_default(): 包装任何 tf 函数。最后我手动关闭 session

sess.close()
adv_sess.close()

关于python - 在同一 Tensorflow session 中从 Saver 加载两个模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41607144/

相关文章:

python - 使用 Selenium Python 滚动到底部

tensorflow - 为什么 'tf.python_io.TFRecordWriter' 在 TensorFlow 中如此缓慢和存储消耗?

python - 如何在 Keras 中添加常量张量?

python - 如何使用 Python 从 Selenium 的重定向链中获取中间 URL?

python - lua cjson 无法解码特定的 unicode 字符?

python - np.where( ) 不改变所有的值

Python Regex 循环非捕获组

python - 将 [28,28,2] matlab 数组转换为 [2, 28, 28, 1] 张量

python - 检查值是否包含在张量中

python - Tensorflow ReLu 不起作用?