python - TF2 padded_batch on a supervised dataset

Tags: python tensorflow tensorflow-datasets tensorflow2.x

Problem setup

I'm following this tutorial. The tutorial starts by loading a supervised dataset (using tfds.load with as_supervised=True):

(train_data, test_data), info = tfds.load(
    'imdb_reviews/subwords8k', 
    split = (tfds.Split.TRAIN, tfds.Split.TEST), 
    with_info=True, as_supervised=True)
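With as_supervised=True, each dataset element is a (text, label) tuple rather than a dict, and padded_shapes has to mirror that structure. The sketch below is a hypothetical stand-in for the IMDB data, so nothing has to be downloaded: toy_data (a made-up name) holds (token_ids, label) pairs where token_ids is a variable-length 1-D int64 vector and label is a scalar. It assumes the real subwords8k elements have the same vector/scalar layout, which the "input component shape ()" in the error further down suggests for the label. element_spec is the standard tf.data attribute for inspecting that structure.

```python
import tensorflow as tf

# Hypothetical stand-in for imdb_reviews/subwords8k with as_supervised=True:
# element n is (token_ids of length n, label n % 2), for n = 1..5.
toy_data = tf.data.Dataset.range(1, 6).map(lambda n: (tf.range(1, n + 1), n % 2))

# element_spec mirrors the (text, label) tuple structure of each element.
text_spec, label_spec = toy_data.element_spec
print(text_spec)   # variable-length vector: unknown length along its one axis
print(label_spec)  # scalar: rank 0, no axes to pad
```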

The tutorial then suggests shuffling and padding the dataset like this:

Movie reviews can be different lengths. We will use the padded_batch method to standardize the lengths of the reviews.

train_batches = train_data.shuffle(1000).padded_batch(10)
test_batches = test_data.shuffle(1000).padded_batch(10)

...but unfortunately the padded_batch method requires an additional argument that the tutorial seems to have forgotten:

Traceback (most recent call last):
  File "imdb_reviews.py", line 14, in <module>
    train_batches = train_data.shuffle(1000).padded_batch(10)
TypeError: padded_batch() missing 1 required positional argument: 'padded_shapes'
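The tensorflow_core paths in the traceback further down point to an early TF 2.x install, where padded_shapes is a required argument. In newer TF 2.x releases the argument is optional: when it is left unset, every dimension of every component is padded to the largest size in the batch, which is presumably the behaviour the tutorial relied on. A minimal sketch, assuming such a newer release and the same hypothetical toy dataset as a stand-in for the IMDB data:

```python
import tensorflow as tf

# Hypothetical toy stand-in: (token_ids, label) pairs with token_ids lengths 1..5.
toy_data = tf.data.Dataset.range(1, 6).map(lambda n: (tf.range(1, n + 1), n % 2))

# Assumes a TF release where padded_shapes is optional; on the install that
# produced the TypeError above, this call fails the same way.
batches = toy_data.padded_batch(2)

# Each text batch is padded to its own longest element; labels are just stacked.
batch_shapes = [(text.shape.as_list(), label.shape.as_list()) for text, label in batches]
print(batch_shapes)
```

Leaving out shuffle() keeps the batch order, and therefore the printed shapes, repeatable.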

Important assumption

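Although the error stack says that padded_shapes is the missing argument, I think it is fair to infer from the tutorial that the missing argument is actually batch_size (which should come before padded_shapes).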
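The second traceback below shows the positional order padded_batch uses internally: batch_size, padded_shapes, padding_values, drop_remainder. Passing the arguments by keyword removes any doubt about which value fills which role. The sketch below does that on the same hypothetical toy dataset, and also shows what a fixed entry such as [8] (like the [10] in the attempt below) means: every text batch is padded to exactly that length, rather than to the longest element in the batch.

```python
import tensorflow as tf

# Hypothetical toy stand-in: (token_ids, label) pairs with token_ids lengths 1..5.
toy_data = tf.data.Dataset.range(1, 6).map(lambda n: (tf.range(1, n + 1), n % 2))

fixed = toy_data.padded_batch(
    batch_size=2,               # number of elements per batch
    padded_shapes=([8], []),    # per component of ONE element: text -> length 8, label stays scalar
    drop_remainder=True,        # drop the final, incomplete batch of 1 element
)

fixed_shapes = [text.shape.as_list() for text, _ in fixed]
print(fixed_shapes)
```

A fixed length only works if no element is longer than it; the toy sequences here are at most 5 tokens long.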

What I tried

I thought this might be easy to fix:

batch_sz = 100 # arbitrary number
train_batches = train_data.shuffle(1000).padded_batch(batch_sz, ([10],[None]))
test_batches = test_data.shuffle(1000).padded_batch(batch_sz, ([10],[None]))

...but my solution is apparently wrong:

Traceback (most recent call last):
  File "imdb_reviews.py", line 15, in <module>
    train_batches = train_data.shuffle(1000).padded_batch(batch_sz, ([10],[None]))
  File "/home/ggiuffre/.local/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 2298, in padded_batch
    batch_size, padded_shapes, padding_values, drop_remainder))
  File "/home/ggiuffre/.local/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 1481, in padded_batch
    drop_remainder)
  File "/home/ggiuffre/.local/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3813, in __init__
    _padded_shape_to_tensor(padded_shape, input_component_shape))
  File "/home/ggiuffre/.local/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3741, in _padded_shape_to_tensor
    % (padded_shape_as_shape, input_component_shape))
ValueError: The padded shape (None,) is not compatible with the corresponding input component shape ().

None 替换为 () 会得到 ValueError: Padded shape [()] must be a 1-D tensor of tf.int64 values,但它的形状为 (1, 0)。

None 替换为 1 会得到 ValueError: The padded shape (1,) is not compatible with the corresponding input component shape ().

Question

What value should I give the padded_shapes argument? Or, more generally, what am I doing wrong here?

Thanks a lot for your help.

Best answer

Take a look at this blog post: https://medium.com/@a.ydobon/tensorflow-2-0-word-embeddings-part3-964b2b9caf94

Recommended:

padded_shapes = ([None], ())
train_batches = train_data.shuffle(1000).padded_batch(10, padded_shapes=padded_shapes)
test_batches = test_data.shuffle(1000).padded_batch(10, padded_shapes=padded_shapes)
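Here 10 is the batch size, [None] pads the variable-length review to the longest review in each batch, and () leaves the scalar label unpadded. To see what a padded batch actually contains, the sketch below applies the same padded_shapes to a hypothetical toy dataset (a stand-in for the IMDB data) and prints the first batch. shuffle() is left out only so that the printed batch is repeatable.

```python
import tensorflow as tf

# Hypothetical toy stand-in: (token_ids, label) pairs with token_ids lengths 1..5.
toy_data = tf.data.Dataset.range(1, 6).map(lambda n: (tf.range(1, n + 1), n % 2))

padded_shapes = ([None], ())  # same structure as in the answer above
batches = toy_data.padded_batch(2, padded_shapes=padded_shapes)

first_text, first_labels = next(iter(batches))
print(first_text.numpy())    # the shorter sequence is right-padded with the default value 0
print(first_labels.numpy())  # labels are simply stacked into a vector
```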

This worked for me.

A similar question about python - TF2 padded_batch on a supervised dataset can be found on Stack Overflow: https://stackoverflow.com/questions/60211521/