python - 在 tensorflow 中训练后如何使用模型(保存/加载图表)

标签 python machine-learning neural-network tensorflow deep-learning

我的tensorflow版本是0.11。 我想在训练后保存一张图,或者保存其他东西让 tensorflow 可以加载它。

我/使用导出和导入元图

我已经读过这篇文章: Tensorflow: how to save/restore a model?

我的 Save.py文件:

X = tf.placeholder("float", [None, 28, 28, 1], name='X')
Y = tf.placeholder("float", [None, 10], name='Y')

tf.train.Saver()
with tf.Session() as sess:
     ...run something ...
     final_tensor = tf.nn.softmax(py_x, name='final_result')
     tf.add_to_collection("final_tensor", final_tensor)

     predict_op = tf.argmax(py_x, 1)
     tf.add_to_collection("predict_op", predict_op)

saver.save(sess, 'my_project') 

然后我运行 load.py:

with tf.Session() as sess:
   new_saver = tf.train.import_meta_graph('my_project.meta')
   new_saver.restore(sess, 'my_project')
   predict_op = tf.get_collection("predict_op")[0]
   for i in range(2):
        test_indices = np.arange(len(teX)) # Get A Test Batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]

        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={"X:0": teX[test_indices],
                                                         "p_keep_conv:0": 1.0,
                                                         "p_keep_hidden:0": 1.0})))

但它返回错误

Traceback (most recent call last):
  File "load_05_convolution.py", line 62, in <module>
    "p_keep_hidden:0": 1.0})))
  File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 717, in run
    run_metadata_ptr)
  File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 894, in _run
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (256, 784) for Tensor u'X:0', which has shape '(?, 28, 28, 1)'

我真的不知道为什么?

如果我添加 final_tensor = tf.get_collection("final_result")[0]

它返回另一个错误:

Traceback (most recent call last):
  File "load_05_convolution.py", line 46, in <module>
    final_tensor = tf.get_collection("final_result")[0]
IndexError: list index out of range

是因为 tf.add_to_collection 只包含一个占位符吗?

II/使用 tf.train.write_graph

我将这一行添加到 save.py 的末尾 tf.train.write_graph(graph, '文件夹', 'train.pb')

它成功地创建了文件'train.pb'

我的 load.py :

with tf.gfile.FastGFile('folder/train.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
  predict_op = sess.graph.get_tensor_by_name('predict_op:0')
  for i in range(2):
        test_indices = np.arange(len(teX)) # Get A Test Batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]

        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={"X:0": teX[test_indices],
                                                         "p_keep_conv:0": 1.0,
                                                         "p_keep_hidden:0": 1.0})))

然后返回错误:

Traceback (most recent call last):
  File "load_05_convolution.py", line 22, in <module>
    graph_def.ParseFromString(f.read())
  File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/message.py", line 185, in ParseFromString
    self.MergeFromString(serialized)
  File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1085, in MergeFromString
    raise message_mod.DecodeError('Unexpected end-group tag.')
google.protobuf.message.DecodeError: Unexpected end-group tag.

您介意分享保存/加载模型的标准方法、代码或教程吗?我真的很困惑。

最佳答案

您的第一个解决方案(使用 MetaGraph)几乎可以工作,但出现错误是因为您将一批扁平化 MNIST 训练示例提供给 tf.placeholder()期望一批 MNIST 训练示例作为形状为 batch_size x height (= 28) x width (= 28) 的 4-D 张量x channel (= 1)。解决此问题的最简单方法是 reshape 输入数据。而不是这个声明:

print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                 sess.run(predict_op, feed_dict={
                     "X:0": teX[test_indices],
                     "p_keep_conv:0": 1.0,
                     "p_keep_hidden:0": 1.0})))

...请尝试以下语句,它会适本地 reshape 您的输入数据:

print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                 sess.run(predict_op, feed_dict={
                     "X:0": teX[test_indices].reshape(-1, 28, 28, 1),
                     "p_keep_conv:0": 1.0,
                     "p_keep_hidden:0": 1.0})))

关于python - 在 tensorflow 中训练后如何使用模型(保存/加载图表),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40956040/

相关文章:

python - 如何在视频中添加音频?

python - 分别循环遍历数据框的每一列

python - 如何通过引用python中pandas数据框中另一列的值来添加新列

r - 如何在 R 中为决策树模型创建增益图?

python - 后向差分编码

python-3.x - Pytorch 中缺乏 L1 正则化的稀疏解决方案

neural-network - 如何选择每个卷积层的过滤器数量?

python - 递归查找混合类型元组中的最大值

java - 无法写入核心转储,Java 运行时环境检测到 fatal error

c++ - OpenCV 3/神经网络/预测错误/Ptr< ANN_MLP >/C++