我使用 TFRecord 来管理我的数据集。
dataset = tf.data.TFRecordDataset(files)
dataset = dataset.map(...)
dataset = dataset.shuffle(...)
dataset = dataset.batch(...)
dataset = dataset.repeat(...)
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()
网络的输出:
logits_batch = network(image_batch)
我使用tf.metrics来展示性能。
acc_value_op, acc_update_op = tf.metrics.accuracy(labels=label_batch, predictions=predict_batch, name="accuracy")
在 tf.Session() 中我有以下代码:
_, loss_value, g_step, _, summary = sess.run(
[train_op, loss_op, g_step_op, acc_update_op, summary_op],
feed_dict={handle: train_iterator_handle})
acc_value = sess.run(
[acc_value_op],
feed_dict={handle: train_iterator_handle})
我将 acc_update_op 放在 acc_value_op 之前,因为我想先更新metrics.accuracy,然后获取metrics.accuracy 结果。
但是让我困惑的是
1) 这两个 sess.run(...) 实际上会获得两批数据还是只是相同的一批数据?
2)我可以获取一批的最新acc值只需使用
acc_value, _ = sess.run([acc_value_op, acc_update_op], feed_dict={.....})
?
最佳答案
数据集迭代器在运行之间维护状态,因此每次调用 run
时,迭代器都会返回一个新的不同批处理。如果您希望它再次返回第一批,您必须初始化迭代器。
行:
acc_value, _ = sess.run([acc_value_op, acc_update_op], feed_dict={.....})
会给你最新的累计准确度值,它实际上相当于:
acc_value = sess.run(acc_update_op, feed_dict={.....})
因为 acc_update_op
的返回值与 acc_value_op
的返回值相同(参见 tf.metrics.accuracy
)。两者之间的唯一区别是运行第二个将更新内部指标变量,以便下次评估它时它将反射(reflect)累积值。请注意,您可以将运行操作的累积指标重置为零,如下所示:
reset_metrics_op = tf.variables_initializer(tf.get_collection(METRIC_VARIABLES))
如果您想同时获得批处理和累积精度值,则可以使用第二个指标:
batch_acc_value_op, _ = tf.metrics.accuracy(
labels=label_batch, predictions=predict_batch, name="batch_accuracy")
关于python - 我的代码从迭代器获取数据多少次?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51880052/