python - 如何缓存和迭代未知大小的数据集?

标签 python python-3.x tensorflow tensorflow-datasets tensorflow2.0

.cache() 步骤添加到我的数据集管道时,连续的训练纪元仍然从网络存储下载数据。

我在网络存储上有一个数据集。我想缓存它,但不重复它:训练纪元必须遍历整个数据集。 这是我的数据集构建管道:

return tf.data.Dataset.list_files(
        file_pattern
    ).interleave(
        tf.data.TFRecordDataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).shuffle(
        buffer_size=2048
    ).batch(
        batch_size=2048,
        drop_remainder=True,
    ).cache(
    ).map(
        map_func=_parse_example_batch,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).prefetch(
        buffer_size=32
    )

如果我按原样使用它,则会在每个时期下载数据集。为了避免这种情况,我必须将 .repeat() 步骤添加到管道中,并使用 model.fit 函数的 steps_per_epoch 关键字。但是,我不知道完整数据集的大小,因此无法传递正确的 steps_per_epoch 值。

缓存和使用未知大小的数据集的正确方法是什么?

谢谢。

<小时/>

编辑

在阅读一些 TF 代码时,我(重新)发现了 make_initializable_iterator 。看来这就是我正在寻找的,也就是说多次迭代同一数据集(在第一次迭代后利用缓存)。但是,这已被弃用,并且不再是 TF2 中主要 API 的一部分。

更新指令是使用for ... in dataset手动迭代数据集。这不是 keras.Model.fit 函数所做的吗?我必须手动编写训练循环才能获得缓存优势吗?

善良。

最佳答案

在 TF2.0 中,您不需要 .repeat()。通过

successives training epochs still download the data from the network storage.

我认为您对消息filling up shuffle buffer感到困惑。如果您使用 shuffle() 函数,这种情况会在每个纪元之前发生。也许尝试不使用 shuffle(),只是为了看看差异。 另外,我建议您在 map() 之后和 batch() 之前使用 cache()

编辑

filling up shuffle buffer

是使用shuffle功能时收到的消息。使用cache()后,您仍然可以shuffle()数据集。看here 另外,如果我理解正确的话,您正在将 map() 的结果数据集提供给您的模型进行训练,那么您应该 cache() 这个数据集而不是另一个数据集,因为将就此进行培训。 要计算数据集中的元素数量,您可以使用以下代码

num_elements = 0
for element in dataset: # tf.dataset type
  num_elements += 1
print ('Total number of elements in the file: ',num_elements)

现在,通过将这个 num_elements 与您的 batch_size 进行比较,您将得到 steps_per_epoch

关于python - 如何缓存和迭代未知大小的数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57977408/

相关文章:

Python 绘图库

python,如何将 pandas 系列转换为 pandas DataFrame?

python - 使用 Python 根据文件内容编写 splunk 查询

python - 需要迭代行以检查条件并在满足条件时从不同列检索值

machine-learning - keras的输入层可以接受自定义输入吗?

python - 当导数未知并且需要一批输出来计算成本时,如何训练模型?

python - Pandas Dataframe 创建一个独特的列

python - 如何使fastapi中的查询参数接受两种类型的输入?

python - Tensorflow - 保存模型

python - 使用非零初始值创建大小为 N 的 bytes() 的大多数 pythonic 方法?