python - 如何从经过训练的 Tensorflow 分类器获得类别预测?

标签 python tensorflow machine-learning

我已经训练了二元分类器模型。模型类包含 self.cost 、 self.initial_state 、 self.final_state 和 self.logits 参数。它只需使用 tf.train.Saver 即可保存:

saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
saver.save(session, 'model.ckpt')

模型训练完成后,我将其加载为:

with tf.variable_scope("Model", reuse=False):
    model = MODEL(config, is_training=False)

with tf.Session() as session:
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(session, 'model.ckpt')

但是,我的 model.run 函数返回交叉熵损失,这是图中的最后一个操作。我不需要损失,我需要每个批处理元素的模型预测

logits = tf.sigmoid(tf.nn.xw_plus_b(last_layer, self.output_w, self.output_b))

其中last_layer是一个800x1矩阵,然后我将其 reshape 为32x25x1(batch_size,sequence_length,1)矩阵。该矩阵包含 [0-1] 范围内的模型预测值。

那么,如何使用该模型对单元素矩阵 1x1x1 进行预测?

最佳答案

添加计算准确性所需的操作,类似于我在下面复制的内容(只需从我手头上最接近的模型中复制出来)。

  self.logits_flat = tf.argmax(logits, axis=1, output_type=tf.int32)
  labels_flat = tf.argmax(labels, axis=1, output_type=tf.int32)
  accuracy = tf.cast(tf.equal(self.logits_flat, labels_flat), tf.float32, name='accuracy')

现在,当您运行模型时(无论是在测试还是训练期间),请为 sess.run 调用添加准确性:

sess.run([train_op, accuracy], feed_dict=...)

sess.run([accuracy, logits], feed_dict=...)

当你调用sess.run时,你所做的就是告诉tensorflow计算你所要求的值。您需要将执行这些计算所需的任何数据传递给它。 Tensorflow 是惰性的,它不会执行任何未明确需要生成您请求的结果的计算。例如。如果您运行上面列出的 sess.run 的第二个版本,优化器将不会运行,因此您的权重将不会更新。

请注意,您可以在网络训练后添加OP,因为它们实际上都不会添加任何变量,因此它们不会影响保存/恢复过程。

关于python - 如何从经过训练的 Tensorflow 分类器获得类别预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49303232/

相关文章:

php - Python 的 None 单例在 PHP 中的等价物是什么?

python - 难以理解如何正确地将 Perl 匹配转换为 Python

python-3.x - Keras: TypeError: run() 得到了一个意外的关键字参数 'kernel_regularizer'

python-3.x - 给予 GAN 的随机噪声应该保持恒定吗?

r - 比较测试性能

python - 返回最后评估的对象的 `any` 的替代方案?

python - 根据多个条件对 python 列表进行排序

machine-learning - 如何在 Keras 中反转 LSTM 输入的形状

r - 如何在 R 中计算 KNN 变量重要性

python - 用于文本分类的 Tensorflow 模型未按预期进行预测?