tensorflow 数据集 shuffle 然后批处理或批处理然后 shuffle

标签 tensorflow tensorflow-datasets

我最近开始学习 tensorflow。

我不确定是否有区别

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.shuffle(buffer_size=4)
ds.batch(4)


x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.batch(4)
ds.shuffle(buffer_size=4)

另外,我不确定为什么我不能使用
dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)

因为它给出了错误
dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
AttributeError: 'TensorSliceDataset' object has no attribute 'shuffle_batch'

谢谢!

最佳答案

TL;DR: 是,有一点不同。几乎总是,您会想调用 Dataset.shuffle() 之前 Dataset.batch() .没有shuffle_batch() tf.data.Dataset 上的方法类,并且您必须分别调用这两个方法来对数据集进行混洗和批处理。
tf.data.Dataset 的转换以与调用它们相同的顺序应用。 Dataset.batch()将其输入的连续元素组合成输出中的单个批处理元素。
通过考虑以下两个数据集,我们可以看到操作顺序的效果:

tf.enable_eager_execution()  # To simplify the example code.

# Batch before shuffle.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.batch(3)
dataset = dataset.shuffle(9)

for elem in dataset:
  print(elem)

# Prints:
# tf.Tensor([1 1 1], shape=(3,), dtype=int32)
# tf.Tensor([2 2 2], shape=(3,), dtype=int32)
# tf.Tensor([0 0 0], shape=(3,), dtype=int32)

# Shuffle before batch.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.shuffle(9)
dataset = dataset.batch(3)

for elem in dataset:
  print(elem)

# Prints:
# tf.Tensor([2 0 2], shape=(3,), dtype=int32)
# tf.Tensor([2 1 0], shape=(3,), dtype=int32)
# tf.Tensor([0 1 1], shape=(3,), dtype=int32)

在第一个版本中(洗牌前的批处理),每批处理的元素是输入中的 3 个连续元素;而在第二个版本中(批处理前洗牌),它们是从输入中随机采样的。通常,当通过(某些变体)小批量 stochastic gradient descent 进行训练时,每个批处理的元素应该从总输入中尽可能均匀地采样。否则,网络可能会过度拟合输入数据中的任何结构,并且生成的网络将无法达到如此高的精度。

关于tensorflow 数据集 shuffle 然后批处理或批处理然后 shuffle,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50437234/

相关文章:

python - 使用 ELMo 嵌入段落

python - 由于保存模型导致训练崩溃 : "tensorflow.GraphDef was modified concurrently during serialization"

python - 将数据从 `tf.data.Dataset` 分发给多个工作人员(例如 Horovod)

python - 如何访问 tf.data.Dataset.list_files() 收集的文件名?

python - 在 Python 中用 tensorflow 用权重填充张量

tensorflow - 杀死tensorflow实例后如何获取 "reset"张量板数据

tensorflow - 在 TensorFlow 中处理多个图表

python - 如何修复 "module ' tensorflow' has no attribute 'estimator' "错误

python - TensorFlow 1.7 + Keras 和数据集 : Object has no attribute 'ndim'

python - 在 tensorflow.data.Dataset.map() 函数中完成的操作梯度