tensorflow - 如何从 `tf.data.Dataset` 的填充批处理中获取序列长度?

标签 tensorflow tensorflow-datasets

例如想象我使用 Librispeech dataset via TFDS (或任何数据集,包括不同长度数据的序列),然后使用 padded_batch创建批处理,例如像这样:

import tensorflow_datasets as tfds

dataset = tfds.load(name="librispeech", split="train_clean100")
dataset = dataset.shuffle(1024)
dataset = dataset.padded_batch(32)

现在,当迭代结果数据集时,即在(填充的)批处理上,我如何知道填充批处理中的原始序列长度?或者此时此信息丢失了?我将如何扩展管道以包含它?是否有像 AddSeqLengthInfoDataset 这样的特殊数据集?这需要在padded_batch之前运行,对吧? (这基本上相当于 my question for TF PaddingFIFOQueue,但对于 tf.data.Dataset。)

有例子吗? (我有点想知道我没有找到任何关于此的信息。我认为当您处理序列时这是一个相当标准的要求,您是否需要有关原始序列长度的信息?)

最佳答案

您可以向数据集中添加一个新字段来保存序列的大小,例如如下所示:

import tensorflow as tf

# Make a dataset with variable-size data
def generate_data():
    for i in range(10):
        yield {'id': i, 'data': range(i % 5)}
ds = tf.data.Dataset.from_generator(generate_data,
                                    {'id': tf.int32, 'data': tf.int32},
                                    {'id': [], 'data': [None]})
# Add field with size of data
ds = ds.map(lambda item: {**item, 'size': tf.shape(item['data'])[0]})
# Padded batch
ds = ds.padded_batch(3)
# Show dataset
for batch in ds:
    tf.print(batch)

输出:

{'data': [[0 0]
 [0 0]
 [0 1]], 'id': [0 1 2], 'size': [0 1 2]}
{'data': [[0 1 2 0]
 [0 1 2 3]
 [0 0 0 0]], 'id': [3 4 5], 'size': [3 4 0]}
{'data': [[0 0 0]
 [0 1 0]
 [0 1 2]], 'id': [6 7 8], 'size': [1 2 3]}
{'data': [[0 1 2 3]], 'id': [9], 'size': [4]}

然后您可以使用例如 tf.sequence_mask使用该字段的值来屏蔽填充值。

另一种选择是将一些特殊的 padding_values 传递给 padded_batch不能出现在实际数据中,例如-1nan,但这取决于这些对于您的问题是否实际上是无效值。

关于tensorflow - 如何从 `tf.data.Dataset` 的填充批处理中获取序列长度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62087640/

相关文章:

python - KERAS 在 docker 容器中添加第一层时随机卡住

python - Tensorflow 没有正式命名的模块

python - "TypeError: ' Tensor ' object is not iterable"错误与tensorflow Estimator

python - 如何在 TensorFlow Recommenders 库中使用自定义 .csv 数据集?

python - tensorflow如何使用多个cpu

python - tensorflow (或 numpy)中特定维度的矩阵乘法

Python 成员资格运算符 "In"TensorFlow 数据集

tensorflow - 如何使tf.data.Dataset在一次调用中返回所有元素?

python - 如何使用 TensorFlow tf.data.Dataset flat_map 生成派生数据集?

tensorflow - CNN 损失停留在 2.302 (ln(10))