例如想象我使用 Librispeech dataset via TFDS (或任何数据集,包括不同长度数据的序列),然后使用 padded_batch
创建批处理,例如像这样:
import tensorflow_datasets as tfds
dataset = tfds.load(name="librispeech", split="train_clean100")
dataset = dataset.shuffle(1024)
dataset = dataset.padded_batch(32)
现在,当迭代结果数据集时,即在(填充的)批处理上,我如何知道填充批处理中的原始序列长度?或者此时此信息丢失了?我将如何扩展管道以包含它?是否有像 AddSeqLengthInfoDataset
这样的特殊数据集?这需要在padded_batch
之前运行,对吧?
(这基本上相当于 my question for TF PaddingFIFOQueue
,但对于 tf.data.Dataset。)
有例子吗? (我有点想知道我没有找到任何关于此的信息。我认为当您处理序列时这是一个相当标准的要求,您是否需要有关原始序列长度的信息?)
最佳答案
您可以向数据集中添加一个新字段来保存序列的大小,例如如下所示:
import tensorflow as tf
# Make a dataset with variable-size data
def generate_data():
for i in range(10):
yield {'id': i, 'data': range(i % 5)}
ds = tf.data.Dataset.from_generator(generate_data,
{'id': tf.int32, 'data': tf.int32},
{'id': [], 'data': [None]})
# Add field with size of data
ds = ds.map(lambda item: {**item, 'size': tf.shape(item['data'])[0]})
# Padded batch
ds = ds.padded_batch(3)
# Show dataset
for batch in ds:
tf.print(batch)
输出:
{'data': [[0 0]
[0 0]
[0 1]], 'id': [0 1 2], 'size': [0 1 2]}
{'data': [[0 1 2 0]
[0 1 2 3]
[0 0 0 0]], 'id': [3 4 5], 'size': [3 4 0]}
{'data': [[0 0 0]
[0 1 0]
[0 1 2]], 'id': [6 7 8], 'size': [1 2 3]}
{'data': [[0 1 2 3]], 'id': [9], 'size': [4]}
然后您可以使用例如 tf.sequence_mask
使用该字段的值来屏蔽填充值。
另一种选择是将一些特殊的 padding_values
传递给 padded_batch
不能出现在实际数据中,例如-1
或 nan
,但这取决于这些对于您的问题是否实际上是无效值。
关于tensorflow - 如何从 `tf.data.Dataset` 的填充批处理中获取序列长度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62087640/