tensorflow - 如何拆分 Tensorflow 数据集?

标签 tensorflow tensorflow-datasets

我有一个基于 .tfrecord 文件的 tensorflow 数据集。如何将数据集拆分为测试和训练数据集?例如。 70% 训练和 30% 测试?

编辑:

我的 Tensorflow 版本:1.8
我已经检查过,可能的副本中没有提到“split_v”函数。此外,我正在使用 tfrecord 文件。

最佳答案

您可以使用 Dataset.take()Dataset.skip() :

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)

为了更一般性,我举了一个使用 70/15/15 train/val/test split 的例子,但如果你不需要测试或验证集,只需忽略最后两行。

Take :

Creates a Dataset with at most count elements from this dataset.



Skip :

Creates a Dataset that skips count elements from this dataset.



您可能还想查看 Dataset.shard() :

Creates a Dataset that includes only 1/num_shards of this dataset.

关于tensorflow - 如何拆分 Tensorflow 数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51125266/

相关文章:

csv - TensorFlow - 给定 5 个默认值时,decode_csv : Expect 3 fields but have 5 in record 0, 会抛出:预期有 5 个字段,但记录 0 中有 3 个字段

python - 为什么在准确率保持不变的情况下损失会减少?

python - 从 numpy 和 scipy.sparse 准备 tensorflow 的数据输入

python - 将 Tensorflow 数据集转换为 2 个包含图像和标签的数组

python - 在联邦学习中将数据拆分为训练和测试

python - 从张量中随机移除

python - Tensorflow 占位符 vs Tensorflow 常量 vs Numpy 数组

python - Tensorflow frozen graph protobuf不预测使用c api

tensorflow - 在 TPU 上使用大型 tensorflow 数据集

python - 如何将窗口化数据集馈送到 Tensorflow 中的 StringLookup 层