python - 如何加载经过训练的 tensorflow 模型

标签 python tensorflow jupyter-notebook

我在加载 tensorflow 模型来测试一些新数据时遇到了各种麻烦。当我训练模型时,我使用了这个:

save_model_file = 'my_saved_model'
saver = tf.train.Saver()
save_path = saver.save(sess, save_model_file)

这似乎会导致创建以下文件:

my_saved_model.meta
checkpoint
my_saved_model.index
my_saved_model.data-00000-of-00001

我不知道我应该注意哪些文件。

现在模型已经训练完毕,我似乎无法在不引发异常的情况下加载或使用它。这是我正在做的事情:

def neural_net_data_input(data_shape):
    theshape=(None,)+tuple(data_shape)
    return tf.placeholder(tf.float32,shape=theshape,name='x')

def neural_net_label_input(n_out):
    return tf.placeholder(tf.float32,shape=(None,n_out),name='one_hot_labels')

def neural_net_keep_prob_input(): 
    return tf.placeholder(tf.float32,name='keep_prob')

def do_generate_network(x):
    #
    # here is where i generate the network layer by layer.
    # this code works fine so i am not showing it here
    #
    pass

#
# Now I want to restore the model
#
tf.reset_default_graph()

input_data_shape=(32,32,1)
final_num_outputs=43

graph1 = tf.Graph()
with graph1.as_default():
    x = neural_net_data_input(input_data_shape)
    one_hot_labels = neural_net_label_input(final_num_outputs)
    keep_prob=neural_net_keep_prob_input()
    logits = do_generate_network(x)
    # Name logits Tensor, so that is can be loaded from disk after training
    logits = tf.identity(logits, name='logits')
    #
    # accuracy: we use this for validation testing
    #
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(one_hot_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')

################################
# Evaluate
################################

new_data=myutils.load_pickle_file(SOME_DATA_FILE_NAME)
new_features=new_data['features']
new_one_hot_labels=new_data['labels']

print('Evaluating on new data...')
with tf.Session(graph=graph1) as sess:
    # Initializing the variables
    sess.run(tf.global_variables_initializer())

    saver.restore(sess,save_model_file)
    new_acc = sess.run(accuracy, feed_dict={x: new_features, one_hot_labels: new_one_hot_labels, keep_prob: 1.})
    print('Testing Accuracy For New Images: {}'.format(new_acc))

但是当我这样做时,我得到了这个:

TypeError: Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist. The operation, 'save/Const', does not exist in the graph.

所以,我尝试在 session 中移动我的图表,如下所示:

################################
# Evaluate
################################

print('Evaluating on web data...')
with tf.Session() as sess:

    x = neural_net_data_input(input_data_shape)
    one_hot_labels = neural_net_label_input(final_num_outputs)
    keep_prob=neural_net_keep_prob_input()
    logits = do_generate_network(x)
    # Name logits Tensor, so that is can be loaded from disk after training
    logits = tf.identity(logits, name='logits')
    #
    # accuracy: we use this for validation testing
    #
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(one_hot_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')

    sess.run(tf.global_variables_initializer())

    my_save_dir="/home/carnd/CarND-Traffic-Sign-Classifier-Project"
    load_model_meta_file=os.path.join(my_save_dir,"my_saved_model.meta")
    load_model_path=os.path.join(my_save_dir,"my_saved_model")
    new_saver = tf.train.import_meta_graph(load_model_meta_file)
    new_saver.restore(sess, load_model_path)

    web_acc = sess.run(accuracy, feed_dict={x: web_features, one_hot_labels: web_one_hot_labels, keep_prob: 1.})
    print('Testing Accuracy For Web Images: {}'.format(web_acc))

现在它运行时不会抛出错误,但它打印的准确度结果是 0.02!我输入的数据与训练期间准确率高达 95% 的数据完全相同。所以看来我以某种方式错误地加载了我的模型。

我做错了什么?

最佳答案

加载训练模型的步骤:

  1. 加载图表: 您可以使用 tf.train.import_meta_graph() 加载图表。示例代码如下:

    model_path = "my_saved_model"
    inference_graph = tf.Graph()
    with tf.Session(graph= inference_graph) as sess:
       # Load the graph with the trained states
      loader = tf.train.import_meta_graph(model_path+'.meta')
      loader.restore(sess, model_path)
    
  2. 获取张量:使用get_tensor_by_name()获取推理所需的张量。因此,在您的模型中,请确保按名称命名张量,以便您可以在推理过程中调用它。

      #Get the tensors by their variable name 
    
      _accuracy = inference_graph.get_tensor_by_name('accuracy:0')
      _x  = inference_graph get_tensor_by_name('x:0')
      _y  = inference_graph.get_tensor_by_name('y:0')
    
  3. 测试:可以通过使用加载的张量来完成。 sess.run(_accuracy, feed_dict={_x: ... , _y:...}

关于python - 如何加载经过训练的 tensorflow 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45744301/

相关文章:

python - 如何在 python matplotlib 中将 Z 轴映射到 3D 图形上

python - Pandas 列表的列以分隔行

r - 输入形状(以喀拉斯计)(此损失期望目标与输出具有相同的形状)

python - 连接到远程机器上现有的 Jupyter 服务器(内核)

python - Pyspark 将多个 csv 文件读入数据框(或 RDD?)

python - BeautifulSoup:获取类文本

python - 从urllib.request向HTTPServer发出许多并发请求时的神秘异常

TensorFlow 将数据加载到 tf.Dataset 所需的时间太长

python - Tensorflow:简单图像图像分类器在时代之间根本不更新

无法从 jupyterhub/jupyter notebook 调用 tensorflow gpu,为什么?