python - 我的代码从迭代器获取数据多少次?

标签 python python-3.x tensorflow iterator dataset

我使用 TFRecord 来管理我的数据集。

dataset = tf.data.TFRecordDataset(files)
dataset = dataset.map(...)
dataset = dataset.shuffle(...)
dataset = dataset.batch(...)
dataset = dataset.repeat(...)
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()

网络的输出:

logits_batch = network(image_batch)

我使用tf.metrics来展示性能。

acc_value_op, acc_update_op = tf.metrics.accuracy(labels=label_batch, predictions=predict_batch, name="accuracy")

在 tf.Session() 中我有以下代码:

_, loss_value, g_step, _, summary = sess.run(
    [train_op, loss_op, g_step_op, acc_update_op, summary_op],
    feed_dict={handle: train_iterator_handle})
acc_value = sess.run(
    [acc_value_op],
    feed_dict={handle: train_iterator_handle})

我将 acc_update_op 放在 acc_value_op 之前,因为我想先更新metrics.accuracy,然后获取metrics.accuracy 结果

但是让我困惑的是

1) 这两个 sess.run(...) 实际上会获得两批数据还是只是相同的一批数据?

2)我可以获取一批的最新acc值只需使用

acc_value, _ = sess.run([acc_value_op, acc_update_op], feed_dict={.....})

最佳答案

数据集迭代器在运行之间维护状态,因此每次调用 run 时,迭代器都会返回一个新的不同批处理。如果您希望它再次返回第一批,您必须初始化迭代器。

行:

acc_value, _ = sess.run([acc_value_op, acc_update_op], feed_dict={.....})

会给你最新的累计准确度值,它实际上相当于:

acc_value = sess.run(acc_update_op, feed_dict={.....})

因为 acc_update_op 的返回值与 acc_value_op 的返回值相同(参见 tf.metrics.accuracy )。两者之间的唯一区别是运行第二个将更新内部指标变量,以便下次评估它时它将反射(reflect)累积值。请注意,您可以将运行操作的累积指标重置为零,如下所示:

reset_metrics_op = tf.variables_initializer(tf.get_collection(METRIC_VARIABLES))

如果您想同时获得批处理和累积精度值,则可以使用第二个指标:

batch_acc_value_op, _ = tf.metrics.accuracy(
    labels=label_batch, predictions=predict_batch, name="batch_accuracy")

关于python - 我的代码从迭代器获取数据多少次?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51880052/

相关文章:

python - Numpy:随机分割数组

python - 从 OSX 为 Windows 创建 Kivy 包

python - 地址的正则表达式在 Regex 101 (Python) 中工作,而不是在 Python 中使用 re.match?

python-3.x - Tensorflow 2.1 全内存和 tf.function 调用两次

python - tf.shape() 在 tensorflow 中得到错误的形状

python - 如何在 Keras 中试验自定义二维卷积核?

python - TensorFlow解码_csv形状错误

python - 列表理解之谜 - Python

python - 按照 init 中显示的顺序表示一个类,无需硬编码

python - 如何使用 float64 nan 选择行?