tensorflow - tf.data group_by_window() 无需首先迭代完整的数据集

标签 tensorflow nlp tensor tensorflow-datasets

具有固定大小的 example_idsfeature_frames 的数据集如下所示(即属于相同 example_id 的所有帧都是连续的):

id | feature_frame
------------------
 0 | [0, 1, 2, 3]
 0 | [4, 3, 2, 1]
 1 | [3, 1, 0, 0]
 2 | [7, 7, 1, 2]
 2 | [2, 7, 1, 2]
 2 | [4, 7, 1, 3]

我想将具有相同 example_id 的所有连续帧批处理到单个张量,即

id| features_batch
------------------
0 | [[0,1,2,3],[4,3,2,1]]
1 | [[3,1,0,0]]
2 | [[7,7,1,2],[2,7,1,2],[4,7,1,3]]

而且我想这样做而不必先遍历整个数据集。 我以为我可以使用 tf.data.experimental.group_by_reducer,但它必须迭代整个数据集,因为它“不知道”一个帧是否带有 example_id 0 将出现在数据集的末尾。

然后,我虽然,好吧,让我们使用 group_by_window - 但令我惊讶的是它还必须迭代完整的数据集,然后再输出一些东西(在 max_len=None 中使用 code>test_reducer() 在两种情况之间切换)。

所以我的问题是 - 有什么方法可以实现这一点(减少所有具有相同 ID 的帧,即 tf.signal.frame 的对立面)?

想象一下,有很多长序列,然后将它们框定为固定大小的窗口,对这些窗口进行一些预测,然后......我们如何将窗口的预测减少为示例的预测。如果我像在下面的代码中那样做,我想如果数据集足够大,在某个时候我会得到一个 OOM,因为在开始减少之前,reducer 似乎首先从完整的数据集中收集所有帧。

class FramesDS:

    def example_frames_ds(self, max_len=None):
        def gen():
            example_count = 0
            while True:
                example_id = np.random.randint(low=np.iinfo(np.int64).min,
                                               high=np.iinfo(np.int64).max, dtype='int64')
                frame_count = np.random.randint(low=2, high=10)
                print("{:4d}| {}: frame_count:{}".format(example_count, example_id, frame_count))
                example_count += 1

                max_seq_len = 8
                frames  = np.random.randint(low=0, high=9, size=(frame_count, max_seq_len))
                example_ids = [example_id] * frame_count
                yield example_ids, frames
        ds = tf.data.Dataset.from_generator(gen,
                                            output_types=(tf.int64, tf.int32),
                                            output_shapes=(tf.TensorShape([None, ]), tf.TensorShape([None, None])))
        if max_len is not None:
            ds = ds.take(max_len)
        return ds


class ReducerTestCase(unittest.TestCase):
    def test_reducer(self):

        max_len=100    ### SET to None - to check that no output is generated

        ds = FramesDS().example_frames_ds(max_len)

        def key_fn(example_id, frame):
            return example_id

        def init_fn(example_id):
            return example_id, tf.zeros([0,], dtype=tf.int32)

        def reduce_fn(state, rinput):
            state_eid, frames = state
            example_id, frame = rinput
            tf.assert_equal(state_eid, example_id)
            frames = tf.concat([tf.reshape(frames, (tf.shape(frames)[0],
                                                    tf.shape(frame)[-1])),
                                tf.expand_dims(frame, axis=0)], axis=0)
            return example_id, frames

        def fin_fn(example_id, frames):
            return example_id, frames

        reducer = tf.data.experimental.Reducer(init_func=init_fn,
                                               reduce_func=reduce_fn,
                                               finalize_func=fin_fn)

        ds = ds.unbatch().batch(8)
        ds = ds.unbatch()

        def window_reduce_fn(key, ds):
            ds = ds.apply(tf.data.experimental.group_by_reducer(key_func=key_fn, reducer=reducer))
            return ds

        ds = ds.apply(tf.data.experimental.group_by_window(key_func=key_fn,
                                                           reduce_func=window_reduce_fn,
                                                           window_size=20))

        for example_id, frames in tqdm(ds):
            print("{}: {}".format(example_id, frames.shape))


最佳答案

您可以使用 tf.data.experimental.scan 完成此操作:

import tensorflow as tf
tf.enable_v2_behavior()

ds = tf.data.Dataset.from_tensor_slices(([0, 0, 1, 2, 2, 2], 
    [[0, 1, 2, 3], 
     [4, 3, 2, 1], 
     [3, 1, 0, 0], 
     [7, 7, 1, 2], 
     [2, 7, 1, 2], 
     [4, 7, 1, 3]]))

empty_val = tf.zeros((0,), tf.int32)

def scan_fn(state, elem):
  if elem[0] == state[0]:
    vals = tf.concat([state[1], elem[1]], 0)
    return ((elem[0], vals), empty_val)
  else:
    return ((elem[0], elem[1]), state[1])

# Append an empty element so that scan produces the final group.
ds = ds.concatenate(tf.data.Dataset.from_tensors((-1, empty_val)))
ds = ds.apply(tf.data.experimental.scan((-1, empty_val), scan_fn))
ds = ds.filter(lambda x: len(x) > 0)
for elem in ds:
  print(elem)

# Prints
# tf.Tensor([0 1 2 3 4 3 2 1], shape=(8,), dtype=int32)
# tf.Tensor([3 1 0 0], shape=(4,), dtype=int32) 
# tf.Tensor([7 7 1 2 2 7 1 2 4 7 1 3], shape=(12,), dtype=int32)

关于tensorflow - tf.data group_by_window() 无需首先迭代完整的数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59216451/

相关文章:

matrix - PyTorch 向量/矩阵/张量的逐元素乘积

Tensorflow:如何处理多个输入

python - Keras神经网络: ValueError - input shape is wrong

c++ - 使用 TensorFlow 训练模型和 C API 进行预测

python - 如何使用 nltk.stem.snowball 阻止 Shakespere/KJV

自然语言处理 : Is Gazetteer a cheat

image - Tensorflow - 洗牌和分割图像和标签的数据集

tensorflow - 在 Keras 中使用多个 TPU 和 TF

haskell - Haskell 是否有任何类型的统计自然语言处理库?

python - 一次迭代两个 Pytorch 张量?