tensorflow - tf.data WindowDataset flat_map 给出 'dict' 对象没有属性 'batch' 错误

标签 tensorflow tensorflow2.0 tf.data.dataset

我正在尝试执行 (batch_size, time_steps, my_data) 类型的批处理

为什么在 flat_map 步骤中我得到 AttributeError: 'dict' object has no attribute 'batch'

 x_train = np.random.normal(size=(60000, 768))
    token_type_ids = np.ones(shape=(len(x_train)))
    position_ids = np.random.normal(size=(x_train.shape[0], 5))

    features_ds = tf.data.Dataset.from_tensor_slices({'inputs_embeds': x_train,
                                                      'token_type_ids': token_type_ids,
                                                      'position_ids': position_ids})
    y_ds = tf.data.Dataset.from_tensor_slices(y_train)
    ds = tf.data.Dataset.zip((features_ds, y_ds))
    # result = list(ds.as_numpy_iterator())

    result_ds = ds.window(size=time_steps, shift=time_steps, stride=1, drop_remainder=True). \
        flat_map(lambda x, y: tf.data.Dataset.zip((x.batch(time_steps), y.batch(time_steps))))

知道问题出在哪里吗?以及如何解决?

最佳答案

您可以添加批处理作为单独的步骤:

x_train = np.random.normal(size=(60000, 768))
token_type_ids = np.ones(shape=(len(x_train)))
position_ids = np.random.normal(size=(x_train.shape[0], 5))

features_ds = tf.data.Dataset.from_tensor_slices({'inputs_embeds': x_train,
                                                  'token_type_ids': token_type_ids,
                                                  'position_ids': position_ids})
y_train = np.random.normal(size=(60000, 1))
y_ds = tf.data.Dataset.from_tensor_slices(y_train)
ds = tf.data.Dataset.zip((features_ds, y_ds))

result_ds = ds.window(size=time_steps, shift=time_steps, stride=1, drop_remainder=True).\
    flat_map(lambda x, y: tf.data.Dataset.zip((x, y)))

time_steps=3
result_ds=result_ds.batch(time_steps)

for i in result_ds.take(1):
    print(i)

关于tensorflow - tf.data WindowDataset flat_map 给出 'dict' 对象没有属性 'batch' 错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63366229/

相关文章:

python - TensorFlow 模型获得零损失

tensorflow - 如何获取tensorflow.data.Dataset的行数、列数/维度数?

python - 在 EagerTensor 中使用不同的数据类型

python - 将 TensorFlow 张量转换为 Numpy 数组

python - 在转换过程中从 tensorflow 对象中提取numpy值

python - TensorFlow 模型的 tf.data 管道中存在问题

python - 不能使用 TensorFlow 变量两次

python - 如何解读TensorFlow的卷积滤波器和stridding参数?

python - 如何提高数据输入管道性能?

python - 如何在 .map 函数中访问张量形状?