我想做的是同时运行多个预训练的 Tensorflow 网络。因为每个网络内部的一些变量的名称可以相同,所以常见的解决方案是在创建网络时使用名称范围。然而,问题是我已经训练了这些模型并将训练的变量保存在几个检查点文件中。在创建网络时使用名称范围后,我无法从检查点文件加载变量。
例如,我训练了一个 AlexNet,我想比较两组变量,一组来自 epoch 10(保存在文件 epoch_10.ckpt 中),另一组来自 epoch 50(保存在文件 epoch_10.ckpt 中)文件 epoch_50.ckpt)。因为这两个是完全一样的net,所以里面的变量名是完全一样的。我可以使用
创建两个网络with tf.name_scope("net1"):
net1 = CreateAlexNet()
with tf.name_scope("net2"):
net2 = CreateAlexNet()
但是,我无法从 .ckpt 文件加载经过训练的变量,因为当我训练这个网络时,我没有使用名称范围。尽管我可以在训练网络时将名称范围设置为“net1”,但这会阻止我加载 net2 的变量。
我试过:
with tf.name_scope("net1"):
mySaver.restore(sess, 'epoch_10.ckpt')
with tf.name_scope("net2"):
mySaver.restore(sess, 'epoch_50.ckpt')
这行不通。
解决这个问题的最佳方法是什么?
最佳答案
最简单的解决方案是创建不同的 session ,为每个模型使用单独的图表:
# Build a graph containing `net1`.
with tf.Graph().as_default() as net1_graph:
net1 = CreateAlexNet()
saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')
# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
net2 = CreateAlexNet()
saver2 = tf.train.Saver(...)
sess2 = tf.Session(graph=net1_graph)
saver2.restore(sess2, 'epoch_50.ckpt')
如果由于某种原因这不起作用,并且您必须使用单个 tf.Session
(例如,因为您想要在另一个 TensorFlow 计算中组合来自两个网络的结果),最好解决方案是:
- 像您已经在做的那样在名称范围内创建不同的网络,并且
- 单独创建
tf.train.Saver
两个网络的实例,带有一个额外的参数来重新映射变量名称。
当 constructing保存者,您可以将字典作为 var_list
参数传递,将检查点中的变量名称(即没有名称范围前缀)映射到 tf.Variable
对象您在每个模型中创建的。
您可以通过编程方式构建var_list
,并且您应该能够执行如下操作:
with tf.name_scope("net1"):
net1 = CreateAlexNet()
with tf.name_scope("net2"):
net2 = CreateAlexNet()
# Strip off the "net1/" prefix to get the names of the variables in the checkpoint.
net1_varlist = {v.name.lstrip("net1/"): v
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)
# Strip off the "net2/" prefix to get the names of the variables in the checkpoint.
net2_varlist = {v.name.lstrip("net2/"): v
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)
# ...
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")
关于python - 同时运行多个预训练的 Tensorflow 网络,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39175945/