python - 在 2.0 session 中迭代 tf.data.Dataset 的正确方法

标签 python tensorflow tensorflow-datasets tfrecord

我已经从youtube-8m project下载了一些*.tfrecord数据。您可以使用以下命令下载“一小部分”数据:

curl data.yt8m.org/download.py | shard=1,100 分区=2/video/train 镜像=us python

我正在尝试了解如何使用新的 tf.data API。我想熟悉人们迭代数据集的典型方式。我一直在使用 TF 网站上的指南和这张幻灯片:Derek Murray's Slides

这是我定义数据集的方式:

# Use interleave() and prefetch() to read many files concurrently.
files = tf.data.Dataset.list_files("./youtube_vids/*.tfrecord")
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100),
                           cycle_length=8)

# Use num_parallel_calls to parallelize map().
dataset = dataset.map(lambda record: tf.parse_single_example(record, feature_map),
                     num_parallel_calls=2) #

# put in x,y output form
dataset = dataset.map(lambda x: (x['mean_rgb'], x['id']))

# shuffle
dataset = dataset.shuffle(10000)

#one epoch
dataset = dataset.repeat(1)
dataset = dataset.batch(200)

#Use prefetch() to overlap the producer and consumer.
dataset = dataset.prefetch(10)

现在,我知道在急切执行模式下我可以

for x,y in dataset:
    x,y

但是,当我尝试按如下方式创建迭代器时:

# A one-shot iterator automatically initializes itself on first use.
iterator = dset.make_one_shot_iterator()

# The return value of get_next() matches the dataset element type.
images, labels = iterator.get_next()

并使用 session 运行

with tf.Session() as sess:

    # Loop until all elements have been consumed.
    try:
        while True:
            r = sess.run(images)
    except tf.errors.OutOfRangeError:
        pass

我收到警告

Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.

所以,这是我的问题:

在 session 中迭代数据集的正确方法是什么?这只是 v1 和 v2 差异的问题吗?

此外,将数据集直接传递给估计器的建议意味着输入函数还具有一个迭代器,如上面 Derek Murray 的幻灯片中所定义,对吗?

最佳答案

对于 Estimator API,您不必指定迭代器,只需将数据集对象作为输入函数传递即可。

def input_fn(filename):
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.shuffle().repeat()
    dataset = dataset.map(parse_func)
    dataset = dataset.batch()
    return dataset

estimator.train(input_fn=lambda: input_fn())

在 TF 2.0 数据集变得可迭代,因此,正如警告消息所示,您可以使用

for x,y in dataset:
    x,y

关于python - 在 2.0 session 中迭代 tf.data.Dataset 的正确方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56399919/

相关文章:

python - 基于一维唯一元素过滤numpy结构化数组

tensorflow - tf.data.Dataset 中大量数据集的最佳数据流和处理解决方案

tensorflow - 将 Tensorflow 数据集 API 创建的数据集拆分为训练和测试?

python - 从套接字 : Is it guaranteed to at least get x bytes? 读取

python - python 保护文件的最佳方法?

python-3.x - Keras 方法 'predict' 和 'predict_generator' 具有不同的结果

tensorflow - tensorflow 中是否有用于目标检测的 ROI 池化层的计划?

python - 从 HDFS、TFRecordDataset+num_parallel_read 等远程主机读取时哪个更好?或 parallel_interleave

python - 压缩属于同月的文件

python - 使用 CUDA GeForce9600GT 在 Ubuntu 服务器上执行 Tensorflow