Tensorflow Dataset API 在完成一个 epoch 后恢复迭代器

标签 tensorflow tensorflow-datasets

我有 190 个特征和标签,我的批量大小为 20,但经过 9 次迭代 tf.reshape 返回异常 reshape 的输入是具有 21 个值的张量,但请求的形状有60,我知道这是由于 Iterator.get_next()。我如何恢复我的迭代器,以便它再次从头开始提供批处理服务?

最佳答案

如果您想重新启动tf.data.Iterator从其数据集的开头,考虑使用可初始化迭代器,它有一个可以运行来重新初始化迭代器的操作:

dataset = ...  # A `tf.data.Dataset` instance.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

train_op = ...  # Something that depends on `next_element`.

for _ in range(NUM_EPOCHS):
  # Initialize the iterator at the beginning of `dataset`.
  sess.run(iterator.initializer)

  # Loop over the examples in `iterator`, running `train_op`.
  try:
    while True:
      sess.run(train_op)

  except tf.errors.OutOfRangeError:  # Thrown at the end of the epoch.
    pass

  # Perform any per-epoch computations here.

有关不同类型的迭代器的更多详细信息,请参阅 the tf.data programmer's guide .

关于Tensorflow Dataset API 在完成一个 epoch 后恢复迭代器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49216946/

相关文章:

python - 如何在 Estimator 训练期间动态加载数据集的新部分?

tensorflow - 如何使用 TensorFlow 的数据集 API 多次迭代数据集?

python - as_list() 未在未知 TensorShape 上定义

python - Tensorflow:低级 LSTM 实现

python - tensorflow 2 : shape mismatch when serialize and decode it back

javascript - 如何在 Tensorflow.js 中保护(混淆/DRM)经过训练的模型权重?

python - 当其中有未知元素时,操纵张量形状的正确方法是什么?

python - 验证损失仅在第一个时期为零

python - 如何使用实验性保存和加载方法保存和使用 Tensorflow 数据集?

python - 如何将 2 个 keras 模型的输出连接到一个单独的层?