将 .cache()
步骤添加到我的数据集管道时,连续的训练纪元仍然从网络存储下载数据。
我在网络存储上有一个数据集。我想缓存它,但不重复它:训练纪元必须遍历整个数据集。 这是我的数据集构建管道:
return tf.data.Dataset.list_files(
file_pattern
).interleave(
tf.data.TFRecordDataset,
num_parallel_calls=tf.data.experimental.AUTOTUNE
).shuffle(
buffer_size=2048
).batch(
batch_size=2048,
drop_remainder=True,
).cache(
).map(
map_func=_parse_example_batch,
num_parallel_calls=tf.data.experimental.AUTOTUNE
).prefetch(
buffer_size=32
)
如果我按原样使用它,则会在每个时期下载数据集。为了避免这种情况,我必须将 .repeat()
步骤添加到管道中,并使用 model.fit
函数的 steps_per_epoch
关键字。但是,我不知道完整数据集的大小,因此无法传递正确的 steps_per_epoch
值。
缓存和使用未知大小的数据集的正确方法是什么?
谢谢。
<小时/>编辑
在阅读一些 TF 代码时,我(重新)发现了 make_initializable_iterator
。看来这就是我正在寻找的,也就是说多次迭代同一数据集(在第一次迭代后利用缓存)。但是,这已被弃用,并且不再是 TF2 中主要 API 的一部分。
更新指令是使用for ... in dataset
手动迭代数据集。这不是 keras.Model.fit 函数所做的吗?我必须手动编写训练循环才能获得缓存优势吗?
善良。
最佳答案
在 TF2.0 中,您不需要 .repeat()
。通过
successives training epochs still download the data from the network storage.
我认为您对消息filling up shuffle buffer
感到困惑。如果您使用 shuffle()
函数,这种情况会在每个纪元之前发生。也许尝试不使用 shuffle()
,只是为了看看差异。
另外,我建议您在 map()
之后和 batch()
之前使用 cache()
。
编辑
filling up shuffle buffer
是使用shuffle
功能时收到的消息。使用cache()
后,您仍然可以shuffle()
数据集。看here
另外,如果我理解正确的话,您正在将 map()
的结果数据集提供给您的模型进行训练,那么您应该 cache()
这个数据集而不是另一个数据集,因为将就此进行培训。
要计算数据集中的元素数量,您可以使用以下代码
num_elements = 0
for element in dataset: # tf.dataset type
num_elements += 1
print ('Total number of elements in the file: ',num_elements)
现在,通过将这个 num_elements
与您的 batch_size
进行比较,您将得到 steps_per_epoch
关于python - 如何缓存和迭代未知大小的数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57977408/