python - TensorFlow:恢复多个图

标签 python machine-learning tensorflow

假设我们有两个 TensorFlow 计算图,G1G2,保存了权重 W1W2。假设我们通过构建 G1G2 来构建一个新图 G。我们如何为这个新图 G 恢复 W1W2

举个简单的例子:

import tensorflow as tf

V1 = tf.Variable(tf.zeros([1]))
saver_1 = tf.train.Saver()
V2 = tf.Variable(tf.zeros([1]))
saver_2 = tf.train.Saver()

sess = tf.Session()
saver_1.restore(sess, 'W1')
saver_2.restore(sess, 'W2')

在此示例中,saver_1 成功恢复了相应的 V1,但 saver_2 失败并返回了 NotFoundError

最佳答案

您可以使用两个保存程序,每个保存程序只查找其中一个变量。如果您只使用 tf.train.Saver(),我认为它会查找您定义的所有变量。您可以使用 tf.train.Saver([v1, ...]) 为其提供要查找的变量列表。有关详细信息,您可以在此处阅读有关 tf.train.Saver 构造函数的信息:https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops.html#Saver

这是一个简单的工作示例。假设您在文件“save_vars.py”中进行计算,它具有以下代码:

import tensorflow as tf

# Graph 1 - set v1 to have value [1.0]
g1 = tf.Graph()
with g1.as_default():
    v1 = tf.Variable(tf.zeros([1]), name="v1")
    assign1 = v1.assign(tf.constant([1.0]))
    init1 = tf.initialize_all_variables()
    save1 = tf.train.Saver()

# Graph 2 - set v2 to have value [2.0]
g2 = tf.Graph()
with g2.as_default():
    v2 = tf.Variable(tf.zeros([1]), name="v2")
    assign2 = v2.assign(tf.constant([2.0]))
    init2 = tf.initialize_all_variables()
    save2 = tf.train.Saver()

# Do the computation for graph 1 and save
sess1 = tf.Session(graph=g1)
sess1.run(init1)
print sess1.run(assign1)
save1.save(sess1, "tmp/v1.ckpt")

# Do the computation for graph 2 and save
sess2 = tf.Session(graph=g2)
sess2.run(init2)
print sess2.run(assign2)
save2.save(sess2, "tmp/v2.ckpt")

如果您确保您有一个 tmp 目录并运行 python save_vars.py,您将获得保存的检查点文件。

现在,您可以通过以下代码使用名为“restore_vars.py”的文件进行恢复:

import tensorflow as tf

# The variables v1 and v2 that we want to restore
v1 = tf.Variable(tf.zeros([1]), name="v1")
v2 = tf.Variable(tf.zeros([1]), name="v2")

# saver1 will only look for v1
saver1 = tf.train.Saver([v1])
# saver2 will only look for v2
saver2 = tf.train.Saver([v2])
with tf.Session() as sess:
    saver1.restore(sess, "tmp/v1.ckpt")
    saver2.restore(sess, "tmp/v2.ckpt")
    print sess.run(v1)
    print sess.run(v2)

当你运行 python restore_vars.py 时,输出应该是

[1.]
[2.]

(至少在我的电脑上是输出)。如果有任何不清楚的地方,请随时发表评论。

关于python - TensorFlow:恢复多个图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40098743/

相关文章:

python - python 磁盘上的 LRU 缓存

machine-learning - 为什么我在使用 Wakari 'print(digits.images(0))' 数据集时在 'digits' 处收到错误

matlab - 为训练数据创建目标值 - 神经网络

machine-learning - 建议在 Keras 训练中使用 Kfold splits 或validation_split kwarg?

python - 线性模型不支持将字符串转换为 float

tensorflow - tensorflow 中的外积

python - 如何根据已知的html id编写输入数据处理器?

python - 如何在python中将字符串放在文件的前面

javascript - 跨站json rpc : Python server side and Mozilla extension using Javascript client side

python - 我尝试在 tensorflow 数据管道上使用 albumentations 时出错