python - 同时运行多个预训练的 Tensorflow 网络

标签 python tensorflow

我想做的是同时运行多个预训练的 Tensorflow 网络。因为每个网络内部的一些变量的名称可以相同,所以常见的解决方案是在创建网络时使用名称范围。然而,问题是我已经训练了这些模型并将训练的变量保存在几个检查点文件中。在创建网络时使用名称范围后,我无法从检查点文件加载变量。

例如,我训练了一个 AlexNet,我想比较两组变量,一组来自 epoch 10(保存在文件 epoch_10.ckpt 中),另一组来自 epoch 50(保存在文件 epoch_10.ckpt 中)文件 epoch_50.ckpt)。因为这两个是完全一样的net,所以里面的变量名是完全一样的。我可以使用

创建两个网络
with tf.name_scope("net1"):
    net1 = CreateAlexNet()
with tf.name_scope("net2"):
    net2 = CreateAlexNet()

但是,我无法从 .ckpt 文件加载经过训练的变量,因为当我训练这个网络时,我没有使用名称范围。尽管我可以在训练网络时将名称范围设置为“net1”,但这会阻止我加载 net2 的变量。

我试过:

with tf.name_scope("net1"):
    mySaver.restore(sess, 'epoch_10.ckpt')
with tf.name_scope("net2"):
    mySaver.restore(sess, 'epoch_50.ckpt')

这行不通。

解决这个问题的最佳方法是什么?

最佳答案

最简单的解决方案是创建不同的 session ,为每个模型使用单独的图表:

# Build a graph containing `net1`.
with tf.Graph().as_default() as net1_graph:
  net1 = CreateAlexNet()
  saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')

# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
  net2 = CreateAlexNet()
  saver2 = tf.train.Saver(...)
sess2 = tf.Session(graph=net1_graph)
saver2.restore(sess2, 'epoch_50.ckpt')

如果由于某种原因这不起作用,并且您必须使用单个 tf.Session(例如,因为您想要在另一个 TensorFlow 计算中组合来自两个网络的结果),最好解决方案是:

  1. 像您已经在做的那样在名称范围内创建不同的网络,并且
  2. 单独创建tf.train.Saver两个网络的实例,带有一个额外的参数来重新映射变量名称。

constructing保存者,您可以将字典作为 var_list 参数传递,将检查点中的变量名称(即没有名称范围前缀)映射到 tf.Variable 对象您在每个模型中创建的。

您可以通过编程方式构建var_list,并且您应该能够执行如下操作:

with tf.name_scope("net1"):
  net1 = CreateAlexNet()
with tf.name_scope("net2"):
  net2 = CreateAlexNet()

# Strip off the "net1/" prefix to get the names of the variables in the checkpoint.
net1_varlist = {v.name.lstrip("net1/"): v
                for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)

# Strip off the "net2/" prefix to get the names of the variables in the checkpoint.
net2_varlist = {v.name.lstrip("net2/"): v
                for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)

# ...
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")

关于python - 同时运行多个预训练的 Tensorflow 网络,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39175945/

相关文章:

python - 为列表中的不同数字切换列表中的字符

python - 我可以使用嵌套的 for 循环遍历列表中的完整字符串吗?

python - DJANGO 在保存时添加 NULL - IntegrityError at/****/(1048, "***' 不能为 null")

python - 如何使用 NDB map 生成游标

python - 如何将 listA 列 1 值匹配并替换为与 ListB 列 1 匹配的 ListB 列 2 值,就像我们在 vlookup 中所做的那样

tensorflow - 从Anaconda环境中卸载TensorFlow

tensorflow - 在卡住的 Keras 模型中,dropout 层是否仍然处于事件状态(即 trainable=False)?

python - TF2.0中如何将tf.data.Dataset类型的数据切片到一定长度?

python - 如何让 tf.data.Dataset.from_tensor_slices 接受我的 dtype?

tensorflow - TensowFlow GradientDescentOptimizer 在这个例子中做了什么?