我想使用tf.metrics.accuracy
跟踪我的预测的准确性,但我不确定如何使用函数返回的 update_op(acc_update_op
):
accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions)
我认为将其添加到 tf.GraphKeys.UPDATE_OPS
中会有意义,但我不确定如何执行此操作。
最佳答案
tf.metrics.accuracy
是众多流式度量 TensorFlow 操作之一(其中另一个是 tf.metrics.recall
)。创建后,将创建两个变量(count
和 total
),以便累积所有传入结果以获得一个最终结果。第一个返回值是用于计算计数/总计
的张量。返回的第二个操作是更新这些变量的有状态函数。在评估分类器对多批数据的性能时,流式度量函数非常有用。一个简单的使用示例:
# building phase
with tf.name_scope("streaming"):
accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions)
test_fetches = {
'accuracy': accuracy,
'acc_op': acc_update_op
}
# when testing the classifier
with tf.name_scope("streaming"):
# clear counters for a fresh evaluation
sess.run(tf.local_variables_initializer())
for _i in range(n_batches_in_test):
fd = get_test_batch()
outputs = sess.run(test_fetches, feed_dict=fd)
print("Accuracy:", outputs['accuracy'])
I was thinking that adding it to
tf.GraphKeys.UPDATE_OPS
would make sense, but I am not sure how to do this.
除非您仅将 UPDATE_OPS 集合用于测试目的,否则这不是一个好主意。通常,集合已经具有用于训练阶段的某些控制操作(例如移动批量归一化参数),这些操作并不意味着与验证阶段一起运行。最好将它们保留在新集合中或手动将这些操作添加到获取字典中。
关于machine-learning - 如何使用 tf.metrics.accuracy?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46787174/