假设我有3个tfrecord文件,分别是neg.tfrecord
、pos1.tfrecord
、pos2.tfrecord
。
我用
dataset = tf.data.TFRecordDataset(tfrecord_file)
此代码创建 3 个数据集对象。
我的batch size是400,其中包括200个neg数据、100个pos1数据和100个pos2数据。如何获得所需的数据集?
我将在 keras.fit()(Eager Execution)中使用此数据集对象。
我的tensorflow版本是1.13.1。
之前,我尝试获取每个数据集的迭代器,获取数据后再手动concat,但效率低下,GPU利用率不高。
最佳答案
您可以使用交错
filenames = [tfrecord_file1, tfrecord_file2]
dataset = (Dataset.from_tensor_slices(filenames).interleave(lambda x:TFRecordDataset(x)
dataset = dataset.map(parse_fn)
...
或者您甚至可以尝试并行交错。请参阅https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#interleave https://www.tensorflow.org/api_docs/python/tf/data/experimental/parallel_interleave
关于python - 如何将多个数据集合并为一个数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55154836/