python - 来自生成器的数据集一次生成多个元素

标签 python tensorflow tensorflow-datasets

我正在测试是否需要从已弃用的基于队列的 API 迁移到 TensorFlow 中的数据集 API。

我似乎找不到等效项的一个用例是 tf.train.batchenqueue_many 参数。

我特别想创建一个可以生成“批量”数组的 Python 生成器,其中“批量大小”不一定与用于 SGD 训练更新的数组相同,然后对该数据流应用批处理(即与 tf.train.batch 中的 enqueue_many 一样)。

是否有任何解决方法可以在新的数据集 API 中实现此目的?

最佳答案

尝试使用平面 map

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
n_reads=10
read_batch_size=20
training_batch_size = 2

def mnist_gen():
    mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
    for i in range(n_reads):
        batch_x, batch_y = mnist.train.next_batch(read_batch_size)
        # Yielding a batch instead of single record
        yield batch_x,batch_y
data = tf.data.Dataset.from_generator(mnist_gen,output_types=(tf.float32,tf.float32))
data = data.flat_map(lambda *x: tf.data.Dataset.zip(tuple(map(tf.data.Dataset.from_tensor_slices,x)))).batch(training_batch_size)
# if u yield only batch_x change lambda function to data.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x)))
iter = data.make_one_shot_iterator()
next_item = iter.get_next()

X= next_item[0]
Y = next_item[1]

with tf.Session() as sess:
    for i in range(n_reads*read_batch_size // training_batch_size):
        print(i, sess.run(X))

关于python - 来自生成器的数据集一次生成多个元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53500684/

相关文章:

python - 使用自定义估算器控制纪元

python - 将整数列添加到 PySpark 数据框中的时间戳列

python - 类型错误:int 不可调用

python - Pandas :如何识别具有 dtype 对象但混合类型项目的列?

c++ - TensorFlow C++等于argmax(axis = -1)

machine-learning - TensorFlow:非单热向量的最佳方法?

tensorflow - 带 keras.utils.Sequence 对象或 tf.data.Dataset 的输入管道?

tensorflow - 在 TPU 上使用大型 tensorflow 数据集

python - 如何以正确的格式在文本文件中写入两个 numpy 数组?

python - Tensorflow:使用 num_parallel_calls 的数据集映射没有提供加速