我有一个基于 .tfrecord 文件的 tensorflow 数据集。如何将数据集拆分为测试和训练数据集?例如。 70% 训练和 30% 测试?
编辑:
我的 Tensorflow 版本:1.8
我已经检查过,可能的副本中没有提到“split_v”函数。此外,我正在使用 tfrecord 文件。
最佳答案
您可以使用 Dataset.take()
和 Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
为了更一般性,我举了一个使用 70/15/15 train/val/test split 的例子,但如果你不需要测试或验证集,只需忽略最后两行。
Take :
Creates a Dataset with at most count elements from this dataset.
Skip :
Creates a Dataset that skips count elements from this dataset.
您可能还想查看
Dataset.shard()
:Creates a Dataset that includes only 1/num_shards of this dataset.
关于tensorflow - 如何拆分 Tensorflow 数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51125266/