<分区>
我想在不再次训练网络的情况下使用我的神经网络。 我读过
save_path = saver.save(sess, "model.ckpt")
print("Model saved in file: %s" % save_path)
现在文件夹中有 3 个文件:checkpoint
、model.ckpt
和 model.ckpt.meta
我想在 python 的另一个类中恢复数据,获取我的神经网络的权重并进行单个预测。
我该怎么做?
<分区>
我想在不再次训练网络的情况下使用我的神经网络。 我读过
save_path = saver.save(sess, "model.ckpt")
print("Model saved in file: %s" % save_path)
现在文件夹中有 3 个文件:checkpoint
、model.ckpt
和 model.ckpt.meta
我想在 python 的另一个类中恢复数据,获取我的神经网络的权重并进行单个预测。
我该怎么做?
最佳答案
要保存模型,您可以这样做:
model_checkpoint = 'model.chkpt'
# Create the model
...
...
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
# Create a saver so we can save and load the model as we train it
tf_saver = tf.train.Saver(tf.all_variables())
# (Optionally) do some training of the model
...
...
tf_saver.save(sess, model_checkpoint)
我假设您已经这样做了,因为您已经获得了三个文件。 当你想在另一个类中加载模型时,你可以这样做:
# The same file as we saved earlier
model_checkpoint = 'model.chkpt'
# Create the SAME model as before
...
...
with tf.Session() as sess:
# Restore the model
tf_saver = tf.train.Saver()
tf_saver.restore(sess, model_checkpoint)
# Now your model is loaded with the same values as when you saved,
# and you can do prediction or continue training
关于python - 如何在 tensorflow 中恢复 session ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41037650/
相关文章:
python - 我可以使用 Scikit learn 和 autopy 来玩电子游戏吗?
python - scikit kmeans 不准确的成本\惯性
python - 凯拉斯导入错误 : cannot import name initializations
python - 尽管是一个多行字符串,为什么这个 ASCII 艺术字不打印在多行上?
python - 值错误 : Cannot reshape a tensor (BERT - transfer learning)
python - 在keras中使用conv2D层时,在tf.random.set_seed中设置种子是否还会设置glorot_uniform kernel_initializer使用的种子吗?