I am currently shuffling my data with the following function:
from tensorflow.contrib import data

def input_pipeline(filenames, batch_size):
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
    dataset = data.TextLineDataset(filenames)
    dataset = dataset.map(decode_func)
    dataset = dataset.shuffle(buffer_size=10000)  # Equivalent to min_after_dequeue=10000.
    dataset = dataset.batch(batch_size)
    # Return an *initializable* iterator over the dataset, which will allow us to
    # re-initialize it at the beginning of each epoch.
    return dataset.make_initializable_iterator()
However, it only shuffles within a window of buffer_size elements, and the buffer is filled in file order. My dataset is very large, so I cannot set buffer_size large enough to cover all of it. Is there another way to shuffle the entire dataset?
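To see why a small buffer gives only a local shuffle, here is a pure-Python simulation of the buffered-shuffle algorithm that Dataset.shuffle uses (buffered_shuffle is a hypothetical helper written for illustration, not a TensorFlow API):

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Simulate a buffered shuffle: keep a fixed-size buffer, emit a
    random element from it, then refill from the input stream."""
    rng = random.Random(seed)
    buffer, out = [], []
    for item in items:
        buffer.append(item)
        if len(buffer) > buffer_size:
            # Emit a random element from the current buffer.
            out.append(buffer.pop(rng.randrange(len(buffer))))
    # Drain the remaining buffer at end of stream.
    while buffer:
        out.append(buffer.pop(rng.randrange(len(buffer))))
    return out

shuffled = buffered_shuffle(range(10000), buffer_size=100)
# With a buffer of 100, the first emitted element is always drawn from
# the first 101 inputs -- a late record can never appear early, so the
# result is only locally shuffled.
print(shuffled[0] <= 100)
```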
Best answer
Currently, the Dataset API has no built-in support for shuffling an entire dataset larger than the buffer (e.g. more than 10k examples). According to this thread, the common practice is:
- Randomly shuffle the entire data once using a MapReduce/Spark/Beam/etc. job to create a set of roughly equal-sized files ("shards").
- In each epoch:
  a. Randomly shuffle the list of shard filenames, using Dataset.list_files(...).shuffle(num_shards).
  b. Use dataset.interleave(lambda filename: tf.data.TextLineDataset(filename), cycle_length=N) to mix together records from N different shards.
  c. Use dataset.shuffle(B) to shuffle the resulting dataset. Setting B might require some experimentation, but you will probably want to set it to some value larger than the number of records in a single shard.
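The per-epoch steps (a)–(c) can be sketched as a pure-Python simulation, to make the mixing behavior concrete without a TensorFlow dependency. The names shard_shuffle_epoch and read_shard are hypothetical helpers invented for this sketch, not TensorFlow APIs:

```python
import random

def shard_shuffle_epoch(shard_files, read_shard, cycle_length, buffer_size, rng):
    """One epoch of the recipe above: shuffle shard order, interleave
    records from cycle_length shards at a time, then buffer-shuffle."""
    files = list(shard_files)
    rng.shuffle(files)  # step (a): shuffle the shard filenames
    # Step (b): round-robin interleave over cycle_length shards at a time,
    # mimicking dataset.interleave(..., cycle_length=N).
    interleaved = []
    for i in range(0, len(files), cycle_length):
        readers = [iter(read_shard(f)) for f in files[i:i + cycle_length]]
        while readers:
            for r in list(readers):
                try:
                    interleaved.append(next(r))
                except StopIteration:
                    readers.remove(r)
    # Step (c): buffered shuffle with buffer_size larger than one shard.
    buf, out = [], []
    for rec in interleaved:
        buf.append(rec)
        if len(buf) > buffer_size:
            out.append(buf.pop(rng.randrange(len(buf))))
    while buf:
        out.append(buf.pop(rng.randrange(len(buf))))
    return out

# Usage: four toy shards of 100 records each, buffer larger than a shard.
shards = {f"shard{i}": list(range(i * 100, (i + 1) * 100)) for i in range(4)}
epoch = shard_shuffle_epoch(shards, shards.get, cycle_length=2,
                            buffer_size=150, rng=random.Random(42))
```

Because the shard order, the interleaving, and the buffer draws are all randomized, every record still appears exactly once per epoch, but records from distant parts of the original data can now end up adjacent.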
Regarding "tensorflow - How do I shuffle an entire dataset with TensorFlow?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/44792761/