python - 如何使用以前训练过的模型来获取图像标签 - TensorFlow

标签 python python-2.7 machine-learning neural-network tensorflow

我训练了一个模型(根据 MNIST tutorial )并保存了它:

saver = tf.train.Saver()
save_path = saver.save(sess,'/path/to/model.ckpt')

我想使用保存的模型来为新一批图像找到标签。我加载模型并使用数据库对其进行测试:

# load MNIST data
folds = build_database_tuple.load_data(data_home_dir='/path/to/database')

# starting the session. using the InteractiveSession we avoid build the entiee comp. graph before starting the session
sess = tf.InteractiveSession()

# start building the computational graph
...

BUILD AND DEFINE ALL THE LAYERS

...

y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# TRAIN AND EVALUATION:
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
# Restore variables from disk.
savepath = '/path/to/model.ckpt'
saver.restore(sess, save_path=savepath)
print("Model restored.")

print("test accuracy %g"%accuracy.eval(feed_dict={x: folds.test.images, y_: folds.test.labels, keep_prob: 1.0}))

虽然我可以加载和测试模型,但如何获取包含数据库图像预测的 y' 数组?

我浏览了网络并找到了很多关于这个问题的答案,但我无法将这些答案与这个特定案例相匹配。比如我找到this answer关于 CIFAR10 教程,但它与 MNIST 教程有很大不同。

最佳答案

定义一个用于执行分类的 OP,例如

predictor = tf.argmax(y_conv,1)

然后用新的输入在训练好的模型上运行

print(sess.run(predictor, feed_dict={ x = new_data }))

因为“预测器”不依赖于 y,您不必提供它,它仍然会执行。

如果您只想查看对测试图像的预测,您也可以通过删除准确性评估调用并执行以下操作在一次运行调用中完成这两项操作

acc, predictions = sess.run([accuracy, predictor],
                            feed_dict={x: folds.test.images,
                                       y_: folds.test.labels,
                                       keep_prob: 1.0}))

print('Accuracy', acc)
print('Predictions', predictions)

关于python - 如何使用以前训练过的模型来获取图像标签 - TensorFlow,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40060983/

相关文章:

python - 如何使用多处理模块迭代列表并将其与字典中的键匹配?

python - 如何在 python 中将字符串拆分为多个部分?

python - 使用网站的Python Youtube音频下载器

machine-learning - 从最终模型混淆矩阵重新创建 "multiClassSummary"统计数据

python - 我需要弄清楚如何创建用户定义的函数来计算三个数字的平均值

python - 如何使用 PyQt 实现 QLCDNumbers?

python - 如何用下划线替换大写字母?

python - Django Rest Framework - 如何路由到函数 View

python - UserWarning : Starting from version 2. 2.1,macOS 发行轮中的库文件由 Apple Clang (Xcode_8.3.3) 编译器构建

matlab - LibSVM 成本权重对于不平衡数据不起作用