python-3.x - 使用 tensorflow 训练 CNN 时如何修复 'OutOfRangeError: End of sequence' 错误?

标签 python-3.x tensorflow tensorflow-datasets

我正在尝试使用我自己的数据集训练 CNN。我一直在使用 tfrecord 文件和 tf.data.TFRecordDataset API 来处理我的数据集。它适用于我的训练数据集。但是当我尝试对我的验证数据集进行批处理时,出现了“OutOfRangeError: End of sequence”的错误。上网浏览后,我以为是验证集的batch size问题,我一开始设置为32。但是在我将其更改为 2 之后,代码运行了大约 9 个 epoch,并且错误再次出现。

我使用输入函数来处理数据集,代码如下:

def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()

    features, labels = iterator.get_next()

    return features, labels

对于训练集,“batch_size”设置为 128,“num_epochs”设置为 None,这意味着无限重复。对于验证集,“batch_size”设置为 32(后来设置为 2,仍然无效),“num_epochs”设置为 1,因为我只想通过验证集一次。
我可以保证验证集包含足够的 epoch 数据。因为我已经尝试了下面的代码并且没有引发任何错误:
with tf.Session() as sess:
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    for i in range(450):
        sess.run([features, labels])
        print(labels.shape)

在上面的代码中,当我将数字 450 更改为 500 或任何更大的值时,它会引发“OutOfRangeError”。这可以确认我的验证数据集包含足够 450 次迭代的数据,批量大小为 32。

我尝试对验证集使用较小的批量大小(即 2),但仍然有相同的错误。
我可以在 input_fn 中将“num_epochs”设置为“None”来运行代码以进行验证,但这似乎不是验证的工作方式。请问有什么帮助吗?

最佳答案

这种行为很正常。从 Tensorflow 文档:

If the iterator reaches the end of the dataset, executing the Iterator.get_next() operation will raise a tf.errors.OutOfRangeError. After this point the iterator will be in an unusable state, and you must initialize it again if you want to use it further.



设置dataset.repeat(None)时不报错的原因是因为数据集永远不会耗尽,因为它会无限重复。

要解决您的问题,您应该将代码更改为:
n_steps = 450
...    

with tf.Session() as sess:
    # Training
    features, labels = input_fn(True, training_list, 32, 1, 1)

    for step in range(n_steps):
        sess.run([features, labels])
        ...
    ...
    # Validation
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    try:
        sess.run([features, labels])
        ...
    except tf.errors.OutOfRangeError:
        print("End of dataset")  # ==> "End of dataset"

您还可以对 input_fn 进行一些更改以在每个时期运行评估:
def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_initializable_iterator()
    return iterator

n_epochs = 10
freq_eval = 1

training_iterator = input_fn(True, training_list, 32, 1, 1)
training_features, training_labels = training_iterator.get_next()

val_iterator = input_fn(False, valid_list, 32, 1, 1)
val_features, val_labels = val_iterator.get_next()

with tf.Session() as sess:
    # Training
    sess.run(training_iterator.initializer)
    for epoch in range(n_epochs):
        try:
            sess.run([training_features, training_labels])
        except tf.errors.OutOfRangeError:
            pass

        # Validation
        if (epoch+1) % freq_eval == 0:
            sess.run(val_iterator.initializer)
            try:
                sess.run([val_features, val_labels])
            except tf.errors.OutOfRangeError:
                pass

我建议你仔细看看this official guide如果您想更好地了解幕后发生的事情。

关于python-3.x - 使用 tensorflow 训练 CNN 时如何修复 'OutOfRangeError: End of sequence' 错误?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53930242/

相关文章:

python-3.x - 导入错误 : The HttpLocust class has been renamed to HttpUser in version 1. 0

python - 如何在整个程序的生命周期中唯一标识Python中的类

Python 3 - 多个字典键的交集

python - 如何在不将值保存到磁盘的情况下将张量恢复到过去的值?

python - 在 AWS Lambda 上运行 Tensorflow 2 预测

python - 使用请求登录网站

python - 如何使用越来越少的元素/值来平铺张量?

tensorflow - tf.data.Dataset 是否支持生成字典结构?

python - 我如何在 TensorFlow 中使用我自己的图像?

python - Tensorflow Dataset API 是否完全摆脱 feed_dict 参数?