python - 如何在 tensorflow 中恢复 session ?

标签 python machine-learning tensorflow

<分区>

我想在不再次训练网络的情况下使用我的神经网络。 我读过

save_path = saver.save(sess, "model.ckpt")
print("Model saved in file: %s" % save_path)

现在文件夹中有 3 个文件:checkpointmodel.ckptmodel.ckpt.meta

我想在 python 的另一个类中恢复数据,获取我的神经网络的权重并进行单个预测。

我该怎么做?

最佳答案

要保存模型,您可以这样做:

model_checkpoint = 'model.chkpt'

# Create the model
...
...

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    # Create a saver so we can save and load the model as we train it
    tf_saver = tf.train.Saver(tf.all_variables())

    # (Optionally) do some training of the model
    ...
    ...

    tf_saver.save(sess, model_checkpoint)

我假设您已经这样做了,因为您已经获得了三个文件。 当你想在另一个类中加载模型时,你可以这样做:

# The same file as we saved earlier
model_checkpoint = 'model.chkpt'

# Create the SAME model as before
...
...

with tf.Session() as sess:
    # Restore the model
    tf_saver = tf.train.Saver()
    tf_saver.restore(sess, model_checkpoint)

    # Now your model is loaded with the same values as when you saved,
    #   and you can do prediction or continue training

关于python - 如何在 tensorflow 中恢复 session ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41037650/

相关文章:

python - 在 2x2 网格中绘制形状图

python - 我可以使用 Scikit learn 和 autopy 来玩电子游戏吗?

python - scikit kmeans 不准确的成本\惯性

python - 凯拉斯导入错误 : cannot import name initializations

python - 识别python中的动词时态

python - 声音库 python

python - 尽管是一个多行字符串,为什么这个 ASCII 艺术字不打印在多行上?

python - 值错误 : Cannot reshape a tensor (BERT - transfer learning)

python - 在keras中使用conv2D层时,在tf.random.set_seed中设置种子是否还会设置glorot_uniform kernel_initializer使用的种子吗?

python - 无法安装tensorflow=1.0.0