python - TensorFlow 数据集洗牌每个 Epoch

标签 python tensorflow

manual在 Tensorflow 中的 Dataset 类上,它展示了如何对数据进行混洗以及如何对其进行批处理。然而,如何对每个时期的数据进行洗牌并不明显。我已经尝试了下面的方法,但是数据在第二个纪元中以与第一个纪元完全相同的顺序给出。有人知道如何使用数据集在不同时期之间进行洗牌吗?

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))

最佳答案

我的环境:Python 3.6,TensorFlow 1.4。

TensorFlow 已添加 Dataset进入 tf.data .
data.shuffle的位置需谨慎.在您的代码中,数据的纪元已放入 dataset在您之前的缓冲区 shuffle .这里有两个可用的例子来混洗数据集。

洗牌所有元素

# shuffle all elements
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

输出:
epoch 1
[1 4 5]
[3 0 7]
[6 9 8]
[10  2 11]
epoch 2
[2 0 6]
[1 7 4]
[5 3 8]
[11  9 10]

批次之间混洗,而不是批量混洗
# shuffle between batches, not shuffle in a batch
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
dataset = dataset.shuffle(buffer_size=buffer_size)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

输出:
epoch 1
[0 1 2]
[6 7 8]
[3 4 5]
[6 7 8]
epoch 2
[3 4 5]
[0 1 2]
[ 9 10 11]
[ 9 10 11]

关于python - TensorFlow 数据集洗牌每个 Epoch,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44124376/

相关文章:

tensorflow - 在 tensorflow 中实现空间更改器(mutator)网络会很直接吗?

python - collections.defaultdict 是如何工作的?

python - 使用 pyserial 将数据发送到串行 - 在 Razor 9DOF IMU 上工作

python - 函数返回 np.array 的副本并替换了一些元素

python - 使用 pandas 时间序列的过去 n 小时的变化率

tensorflow - 带有 train_on_batch 的优化器?

python - argparse 中 --default 和 --store_const 的区别

python - “DNN”对象在 ImageDataGenerator() 中没有属性 'fit_generator' - keras - python

python - tensorflow 变量名称中允许使用哪些字符?

c - 如何在 Tensorflow Lite(实验 C API)中创建输入张量并与解释器一起使用?