我有 190 个特征和标签,我的批量大小为 20,但经过 9 次迭代 tf.reshape
返回异常 reshape 的输入是具有 21 个值的张量,但请求的形状有60,我知道这是由于 Iterator.get_next()
。我如何恢复我的迭代器,以便它再次从头开始提供批处理服务?
最佳答案
如果您想重新启动tf.data.Iterator
从其数据集
的开头,考虑使用可初始化迭代器,它有一个可以运行来重新初始化迭代器的操作:
dataset = ... # A `tf.data.Dataset` instance.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
train_op = ... # Something that depends on `next_element`.
for _ in range(NUM_EPOCHS):
# Initialize the iterator at the beginning of `dataset`.
sess.run(iterator.initializer)
# Loop over the examples in `iterator`, running `train_op`.
try:
while True:
sess.run(train_op)
except tf.errors.OutOfRangeError: # Thrown at the end of the epoch.
pass
# Perform any per-epoch computations here.
有关不同类型的迭代器的更多详细信息,请参阅 the tf.data
programmer's guide .
关于Tensorflow Dataset API 在完成一个 epoch 后恢复迭代器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49216946/